From 6caaabb978a4faf932c994eb66e68872f8956477 Mon Sep 17 00:00:00 2001 From: donalevans Date: Wed, 17 Dec 2025 17:04:39 -0800 Subject: [PATCH 1/5] Add support for embedding task to JinaAI service This commit adds support for the multimodal embedding task type to the JinaAI service. In order to enable this, the existing JinaAIEmbeddingsServiceSettings class has been split into two versions, one for text_embedding and one for embedding, with the common behaviour now found in the BaseJinaAIEmbeddingsServiceSettings class. The embedding task supports using models that accept multimodal inputs as well as models that only accept text inputs, so additional logic has been added to JinaAIEmbeddingsRequestEntity.toXContent() to allow the request sent to Jina to be structured appropriately based on the type of model being used. It is necessary to know whether a given list of inputs contains non-text values, both to ensure that the model being used can support multimodal inputs, and to prevent late chunking from being applied, since that setting is not supported by JinaAI for multimodal inputs. To enable this, the InferenceStringGroup class now determines whether any of the InferenceString instances it contains are non-text values when constructed. The response format used by the embedding task differs slightly from the response format used by the text_embedding task, so changes were made to the JinaAIEmbeddingsResponseEntity class to allow the appropriate DenseEmbeddingResults implementation to be returned based on task type. In order to support per-request task settings, additional parsing logic and a new taskSettings field have been added to the EmbeddingRequest class. This should have been present when the EmbeddingRequest class was first introduced, but it was overlooked at the time. 
Other changes in this commit: - Consolidate transport version definitions instead of having the same transport version defined in multiple places for JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED and JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED - Add test coverage for new task type - Greatly expand and clean up existing tests for JinaAI model and service settings classes --- .../inference/EmbeddingRequest.java | 41 +- .../inference/InferenceStringGroup.java | 56 +- .../inference/ServiceSettings.java | 4 + .../jina_ai_embedding_task_added.csv | 1 + .../resources/transport/upper_bounds/9.4.csv | 2 +- .../inference/EmbeddingRequestTests.java | 79 +- .../inference/InferenceStringGroupTests.java | 17 + .../action/EmbeddingActionRequestTests.java | 48 +- .../ChunkingSettingsBuilderTests.java | 3 +- .../inference/InferenceGetServicesIT.java | 2 +- .../InferenceNamedWriteablesProvider.java | 14 +- .../ShardBulkInferenceActionFilter.java | 10 +- .../inference/services/ServiceFields.java | 5 + .../JinaAIEmbeddingsRequestManager.java | 10 +- .../services/jinaai/JinaAIService.java | 70 +- ... 
BaseJinaAIEmbeddingsServiceSettings.java} | 153 +- .../JinaAIEmbeddingServiceSettings.java | 65 + .../embeddings/JinaAIEmbeddingType.java | 6 +- .../embeddings/JinaAIEmbeddingsModel.java | 21 +- .../JinaAITextEmbeddingServiceSettings.java | 66 + .../request/JinaAIEmbeddingsRequest.java | 10 +- .../JinaAIEmbeddingsRequestEntity.java | 40 +- .../JinaAIEmbeddingsResponseEntity.java | 57 +- ...eEmbeddingServiceIntegrationValidator.java | 9 +- .../xpack/inference/InputTypeTests.java | 8 + .../xpack/inference/TaskTypeTests.java | 3 + .../elasticsearch/xpack/inference/Utils.java | 34 +- .../jinaai/JinaAIServiceSettingsTests.java | 8 +- .../services/jinaai/JinaAIServiceTests.java | 1796 ++++++++++------- ...eJinaAIEmbeddingsServiceSettingsTests.java | 280 +++ .../JinaAIEmbeddingServiceSettingsTests.java | 359 ++++ .../JinaAIEmbeddingsModelTests.java | 123 +- ...aAITextEmbeddingServiceSettingsTests.java} | 127 +- .../JinaAIEmbeddingsRequestEntityTests.java | 219 +- .../request/JinaAIEmbeddingsRequestTests.java | 206 +- .../JinaAIRerankServiceSettingsTests.java | 9 +- .../rerank/JinaAIRerankTaskSettingsTests.java | 4 - .../JinaAIEmbeddingsResponseEntityTests.java | 228 ++- ...ddingServiceIntegrationValidatorTests.java | 34 +- 39 files changed, 3030 insertions(+), 1197 deletions(-) create mode 100644 server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/{JinaAIEmbeddingsServiceSettings.java => BaseJinaAIEmbeddingsServiceSettings.java} (62%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java create mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/{JinaAIEmbeddingsServiceSettingsTests.java => JinaAITextEmbeddingServiceSettingsTests.java} (77%) diff --git a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java index 29bf2727562f5..66722d85ec06e 100644 --- a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java @@ -9,6 +9,7 @@ package org.elasticsearch.inference; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -26,9 +27,11 @@ import java.io.IOException; import java.util.List; +import java.util.Map; import java.util.Objects; import static java.util.Collections.singletonList; +import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -62,10 +65,17 @@ * OR *
  * "input": ["first text input", "second text input"]
- * @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for - * @param inputType The {@link InputType} of the request + * + * @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for + * @param inputType The {@link InputType} of the request + * @param taskSettings The map of task settings specific to this request */ -public record EmbeddingRequest(List inputs, InputType inputType) implements Writeable, ToXContentFragment { +public record EmbeddingRequest(List inputs, InputType inputType, Map taskSettings) + implements + Writeable, + ToXContentFragment { + + public static final TransportVersion JINA_AI_EMBEDDING_TASK_ADDED = TransportVersion.fromName("jina_ai_embedding_task_added"); private static final String INPUT_FIELD = "input"; private static final String INPUT_TYPE_FIELD = "input_type"; @@ -73,7 +83,7 @@ public record EmbeddingRequest(List inputs, InputType inpu @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( EmbeddingRequest.class.getSimpleName(), - args -> new EmbeddingRequest((List) args[0], (InputType) args[1]) + args -> new EmbeddingRequest((List) args[0], (InputType) args[1], (Map) args[2]) ); static { @@ -89,31 +99,48 @@ public record EmbeddingRequest(List inputs, InputType inpu new ParseField(INPUT_TYPE_FIELD), ObjectParser.ValueType.STRING ); + PARSER.declareField( + optionalConstructorArg(), + (parser, context) -> parser.mapOrdered(), + new ParseField(TASK_SETTINGS), + ObjectParser.ValueType.OBJECT + ); } public static EmbeddingRequest of(List contents) { - return new EmbeddingRequest(contents, null); + return new EmbeddingRequest(contents, null, null); } - public EmbeddingRequest(List inputs, @Nullable InputType inputType) { + public EmbeddingRequest(List inputs, @Nullable InputType inputType, @Nullable Map taskSettings) { this.inputs = inputs; this.inputType = Objects.requireNonNullElse(inputType, 
InputType.UNSPECIFIED); + this.taskSettings = Objects.requireNonNullElse(taskSettings, Map.of()); } public EmbeddingRequest(StreamInput in) throws IOException { - this(in.readCollectionAsImmutableList(InferenceStringGroup::new), in.readEnum(InputType.class)); + this( + in.readCollectionAsImmutableList(InferenceStringGroup::new), + in.readEnum(InputType.class), + in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED) ? in.readGenericMap() : Map.of() + ); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(inputs); out.writeEnum(inputType); + if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + out.writeGenericMap(taskSettings); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(INPUT_FIELD, inputs); builder.field(INPUT_TYPE_FIELD, inputType); + if (taskSettings.isEmpty() == false) { + builder.field(TASK_SETTINGS, taskSettings); + } return builder; } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java index 986503d059120..134c8082259cd 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.List; +import java.util.Objects; import static java.util.Collections.singletonList; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -37,9 +38,8 @@ * ] * } * - * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector */ -public record InferenceStringGroup(List inferenceStrings) implements Writeable, ToXContentObject { +public final class InferenceStringGroup implements Writeable, ToXContentObject { private static final String CONTENT_FIELD = "content"; 
@SuppressWarnings("unchecked") @@ -47,10 +47,22 @@ public record InferenceStringGroup(List inferenceStrings) imple InferenceStringGroup.class.getSimpleName(), args -> new InferenceStringGroup((List) args[0]) ); + static { PARSER.declareObjectArray(constructorArg(), InferenceString.PARSER::apply, new ParseField(CONTENT_FIELD)); } + private final List inferenceStrings; + private final boolean containsNonTextEntry; + + /** + * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector + */ + public InferenceStringGroup(List inferenceStrings) { + this.inferenceStrings = inferenceStrings; + containsNonTextEntry = inferenceStrings.stream().anyMatch(s -> s.isText() == false); + } + public InferenceStringGroup(StreamInput in) throws IOException { this(in.readCollectionAsImmutableList(InferenceString::new)); } @@ -64,6 +76,14 @@ public InferenceStringGroup(String input) { this(singletonList(new InferenceString(DataType.TEXT, input))); } + public List inferenceStrings() { + return inferenceStrings; + } + + public boolean containsNonTextEntry() { + return containsNonTextEntry; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(inferenceStrings); @@ -81,7 +101,7 @@ public static InferenceStringGroup parse(XContentParser parser) throws IOExcepti var token = parser.currentToken(); if (token == XContentParser.Token.VALUE_STRING) { // Create content object from String - return new InferenceStringGroup(singletonList(new InferenceString(DataType.TEXT, parser.text()))); + return new InferenceStringGroup(parser.text()); } else if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { // Create content object from InferenceString(s) return InferenceStringGroup.PARSER.apply(parser, null); @@ -134,4 +154,34 @@ public static List toInferenceStringList(List toStringList(List inferenceStringGroups) { return 
InferenceString.toStringList(toInferenceStringList(inferenceStringGroups)); } + + /** + * Method used to determine if a list of {@link InferenceStringGroup} contains any {@link InferenceString} that represent a non-text + * value + * + * @param inferenceStringGroups the list of {@link InferenceStringGroup} to check + * @return true if the input list contains any non-text values, false otherwise + */ + public static boolean containsNonTextEntry(List inferenceStringGroups) { + return inferenceStringGroups.stream().anyMatch(InferenceStringGroup::containsNonTextEntry); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) return true; + if (obj == null || obj.getClass() != this.getClass()) return false; + var that = (InferenceStringGroup) obj; + return Objects.equals(this.inferenceStrings, that.inferenceStrings); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceStrings); + } + + @Override + public String toString() { + return "InferenceStringGroup[" + "inferenceStrings=" + inferenceStrings + ']'; + } + } diff --git a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java index f97239f8df8df..81348f65e4847 100644 --- a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java @@ -55,6 +55,10 @@ default DenseVectorFieldMapper.ElementType elementType() { return null; } + default boolean isMultimodal() { + return false; + } + /** * The model to use in the inference endpoint (e.g. text-embedding-ada-002). Sometimes the model is not defined in the service * settings. This can happen for external providers (e.g. 
hugging face, azure ai studio) where the provider requires that the model diff --git a/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv b/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv new file mode 100644 index 0000000000000..017c60f0bb209 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv @@ -0,0 +1 @@ +9257000 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv index 794078470a587..65707aec377e3 100644 --- a/server/src/main/resources/transport/upper_bounds/9.4.csv +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -1 +1 @@ -jina_ai_embedding_refactor,9256000 +jina_ai_embedding_task_added,9257000 diff --git a/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java b/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java index cbf78186ee142..1b120d0d19b21 100644 --- a/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java +++ b/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java @@ -21,7 +21,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.is; public class EmbeddingRequestTests extends AbstractBWCSerializationTestCase { @@ -40,6 +43,7 @@ public void testParser_withSingleString() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -60,6 +64,7 @@ public void testParser_withSingleContentObject() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); 
assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -78,6 +83,7 @@ public void testParser_withStringArray() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -106,6 +112,7 @@ public void testParser_withSingleContentObjectWithMultipleEntries() throws IOExc ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -140,6 +147,7 @@ public void testParser_withMultipleContentObjects() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -173,6 +181,7 @@ public void testParser_withUnspecifiedFormats_usesDefaults() throws IOException ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -189,6 +198,46 @@ public void testParser_withNoInputType() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.UNSPECIFIED)); + assertThat(request.taskSettings(), anEmptyMap()); + } + } + + public void testParser_withTaskSettings() throws IOException { + var requestJson = """ + { + "input": "some text input", + "task_settings": { + "field_one": "value_one", + "field_two": 123 + } + } + """; + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = EmbeddingRequest.PARSER.apply(parser, null); + var expectedInputs = List.of( + new InferenceStringGroup(List.of(new InferenceString(DataType.TEXT, DataFormat.TEXT, "some text input"))) + ); + assertThat(request.inputs(), is(expectedInputs)); + assertThat(request.inputType(), 
is(InputType.UNSPECIFIED)); + assertThat(request.taskSettings(), is(Map.of("field_one", "value_one", "field_two", 123))); + } + } + + public void testParser_withEmptyTaskSettings() throws IOException { + var requestJson = """ + { + "input": "some text input", + "task_settings": {} + } + """; + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = EmbeddingRequest.PARSER.apply(parser, null); + var expectedInputs = List.of( + new InferenceStringGroup(List.of(new InferenceString(DataType.TEXT, DataFormat.TEXT, "some text input"))) + ); + assertThat(request.inputs(), is(expectedInputs)); + assertThat(request.inputType(), is(InputType.UNSPECIFIED)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -203,7 +252,11 @@ protected EmbeddingRequest createTestInstance() { } public static EmbeddingRequest createRandom() { - return new EmbeddingRequest(randomEmbeddingContents(), randomFrom(InputType.values())); + return new EmbeddingRequest( + randomEmbeddingContents(), + randomFrom(InputType.values()), + Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8)) + ); } private static List randomEmbeddingContents() { @@ -216,21 +269,27 @@ private static List randomEmbeddingContents() { @Override protected EmbeddingRequest mutateInstance(EmbeddingRequest instance) throws IOException { - if (randomBoolean()) { - var embeddingContents = instance.inputs(); - return new EmbeddingRequest( - randomValueOtherThan(embeddingContents, EmbeddingRequestTests::randomEmbeddingContents), - instance.inputType() + var embeddingContents = instance.inputs(); + var inputType = instance.inputType(); + var taskSettings = instance.taskSettings(); + switch (randomInt(2)) { + case 0 -> embeddingContents = randomValueOtherThan(embeddingContents, EmbeddingRequestTests::randomEmbeddingContents); + case 1 -> inputType = randomValueOtherThan(inputType, () -> randomFrom(InputType.values())); + case 2 -> taskSettings = randomValueOtherThan( + taskSettings, 
+ () -> Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8)) ); - } else { - InputType inputType = instance.inputType(); - return new EmbeddingRequest(instance.inputs(), randomValueOtherThan(inputType, () -> randomFrom(InputType.values()))); } + return new EmbeddingRequest(embeddingContents, inputType, taskSettings); } @Override protected EmbeddingRequest mutateInstanceForVersion(EmbeddingRequest instance, TransportVersion version) { - return instance; + if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + return instance; + } else { + return new EmbeddingRequest(instance.inputs(), instance.inputType(), Map.of()); + } } @Override diff --git a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java index ed1f7db68b497..8ec4de041f3c6 100644 --- a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java +++ b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.inference.InferenceStringGroup.containsNonTextEntry; import static org.elasticsearch.inference.InferenceStringGroup.toInferenceStringList; import static org.elasticsearch.inference.InferenceStringGroup.toStringList; import static org.hamcrest.Matchers.contains; @@ -31,12 +32,14 @@ public void testStringConstructor() { String stringValue = "a string"; var input = new InferenceStringGroup(stringValue); assertThat(input.inferenceStrings(), contains(new InferenceString(DataType.TEXT, DataFormat.TEXT, stringValue))); + assertThat(input.containsNonTextEntry(), is(false)); } public void testSingleArgumentConstructor() { InferenceString inferenceString = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string"); var input = new InferenceStringGroup(inferenceString); assertThat(input.inferenceStrings(), contains(inferenceString)); + 
assertThat(input.containsNonTextEntry(), is(true)); } public void testValue_withMoreThanOneElement_throws() { @@ -77,6 +80,20 @@ public void testToStringList_withMoreThanOneElement_throws() { assertThat(expectedException.getMessage(), is("Multiple-input InferenceStringGroup passed to InferenceStringGroup.toStringList")); } + public void testContainsNonTextEntry_withOnlyTextInputs() { + var inputs = List.of(new InferenceStringGroup("string1"), new InferenceStringGroup("string2")); + assertThat(containsNonTextEntry(inputs), is(false)); + } + + public void testContainsNonTextEntry_withNonTextInput() { + DataType nonTextDataType = randomValueOtherThan(DataType.TEXT, () -> randomFrom(DataType.values())); + var inputs = List.of( + new InferenceStringGroup("string1"), + new InferenceStringGroup(new InferenceString(nonTextDataType, "non text")) + ); + assertThat(containsNonTextEntry(inputs), is(true)); + } + @Override protected InferenceStringGroup mutateInstanceForVersion(InferenceStringGroup instance, TransportVersion version) { return instance; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java index 1de31682a3524..553fc6ea99b20 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java @@ -23,7 +23,9 @@ import java.io.IOException; import java.util.List; +import java.util.Map; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; import static org.elasticsearch.xpack.core.inference.action.EmbeddingAction.Request.parseRequest; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -39,7 +41,10 @@ public void testParseRequest() throws 
IOException { "content": {"type": "image", "format": "base64", "value": "some image input" } } ], - "input_type": "search" + "input_type": "search", + "task_settings": { + "field": "value" + } } """; try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { @@ -53,7 +58,8 @@ public void testParseRequest() throws IOException { taskType, new EmbeddingRequest( List.of(new InferenceStringGroup(new InferenceString(DataType.IMAGE, DataFormat.BASE64, "some image input"))), - InputType.SEARCH + InputType.SEARCH, + Map.of("field", "value") ), context, timeout @@ -73,7 +79,7 @@ public void testValidate_withNullEmbeddingRequestInputs_returnsValidationExcepti var request = new EmbeddingAction.Request( randomAlphanumericOfLength(8), TaskType.EMBEDDING, - new EmbeddingRequest(null, randomFrom(InputType.values())), + new EmbeddingRequest(null, randomFrom(InputType.values()), Map.of()), new InferenceContext(randomAlphaOfLength(10)), TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); @@ -87,7 +93,7 @@ public void testValidate_withEmptyEmbeddingRequestInputs_returnsValidationExcept var request = new EmbeddingAction.Request( randomAlphanumericOfLength(8), TaskType.EMBEDDING, - new EmbeddingRequest(List.of(), randomFrom(InputType.values())), + new EmbeddingRequest(List.of(), randomFrom(InputType.values()), Map.of()), new InferenceContext(randomAlphaOfLength(10)), TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); @@ -112,10 +118,11 @@ public void testValidate_withNonEmbeddingTaskType_returnsValidationException() { } public void testValidate_withMultipleValidationErrors_returnsAll() { + // Create a request with an invalid task type and null inputs var request = new EmbeddingAction.Request( randomAlphanumericOfLength(8), randomValueOtherThanMany(TaskType.EMBEDDING::isAnyOrSame, () -> randomFrom(TaskType.values())), - new EmbeddingRequest(null, randomFrom(InputType.values())), + new EmbeddingRequest(null, randomFrom(InputType.values()), Map.of()), new 
InferenceContext(randomAlphaOfLength(10)), TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); @@ -128,16 +135,25 @@ public void testValidate_withMultipleValidationErrors_returnsAll() { @Override protected EmbeddingAction.Request mutateInstanceForVersion(EmbeddingAction.Request instance, TransportVersion version) { + // Use empty task settings if node is on a version before Jina AI embedding task support was added, since embedding request task + // settings were added with that change + var embeddingRequest = instance.getEmbeddingRequest(); + if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED) == false) { + embeddingRequest = new EmbeddingRequest(embeddingRequest.inputs(), embeddingRequest.inputType(), Map.of()); + } + + var context = instance.getContext(); if (version.supports(INFERENCE_CONTEXT) == false) { - return new EmbeddingAction.Request( - instance.getInferenceEntityId(), - instance.getTaskType(), - instance.getEmbeddingRequest(), - InferenceContext.EMPTY_INSTANCE, - instance.getTimeout() - ); + context = InferenceContext.EMPTY_INSTANCE; } - return instance; + + return new EmbeddingAction.Request( + instance.getInferenceEntityId(), + instance.getTaskType(), + embeddingRequest, + context, + instance.getTimeout() + ); } @Override @@ -160,7 +176,11 @@ public static EmbeddingAction.Request createRandom() { } private static EmbeddingRequest randomEmbeddingRequest() { - return new EmbeddingRequest(List.of(new InferenceStringGroup(randomAlphanumericOfLength(8))), randomFrom(InputType.values())); + return new EmbeddingRequest( + List.of(new InferenceStringGroup(randomAlphanumericOfLength(8))), + randomFrom(InputType.values()), + Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8)) + ); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java index 
a7a6f78e26b7f..771512f0d1f13 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java @@ -15,14 +15,13 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.DEFAULT_SETTINGS; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.WORDS_PER_TOKEN; public class ChunkingSettingsBuilderTests extends ESTestCase { - public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1); - public void testNullChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(null); assertEquals(ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS, chunkingSettings); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index e356edec7d40c..edd31876f2735 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -232,7 +232,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { } public void testGetServicesWithEmbeddingTaskType() throws IOException { - 
assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service").toArray())); + assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service", "jinaai").toArray())); } private List getAllServices() throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index b5f2bc89e0ec8..ea666531bedf2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -106,8 +106,9 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; -import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAITextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings; @@ -860,8 +861,15 @@ private static void addJinaAINamedWriteables(List namedWriteables.add( new NamedWriteableRegistry.Entry( ServiceSettings.class, - JinaAIEmbeddingsServiceSettings.NAME, - 
JinaAIEmbeddingsServiceSettings::new + JinaAITextEmbeddingServiceSettings.NAME, + JinaAITextEmbeddingServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + JinaAIEmbeddingServiceSettings.NAME, + JinaAIEmbeddingServiceSettings::new ) ); namedWriteables.add( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 06c6c6e3aeeac..1e3d178da9c2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -43,6 +43,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -79,6 +80,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static java.util.Collections.singletonList; import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; @@ -396,8 +398,14 @@ private void executeChunkedInferenceAsync( } } + // This assumes that all inference requests are text only, with no images final List inputs = requests.stream() - .map(r -> new ChunkInferenceInput(new InferenceStringGroup(r.input), r.chunkingSettings)) + .map( + r -> new 
ChunkInferenceInput( + new InferenceStringGroup(singletonList(new InferenceString(InferenceString.DataType.TEXT, r.input))), + r.chunkingSettings + ) + ) .collect(Collectors.toList()); ActionListener> completionListener = ActionListener.wrap(results -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java index 6fed5218a93ad..bd87a04afab78 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java @@ -29,6 +29,11 @@ public final class ServiceFields { * The field used by services other than elasticsearch to determine the embedding type */ public static final String EMBEDDING_TYPE = "embedding_type"; + /** + * The name of the field used to specify whether the model supports multimodal inputs for the + * {@link org.elasticsearch.inference.TaskType#EMBEDDING} task type. Defaults to true. 
+ */ + public static final String MULTIMODAL_MODEL = "multimodal_model"; private ServiceFields() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index ab72aab578365..21f72872627c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -22,7 +21,6 @@ import org.elasticsearch.xpack.inference.services.jinaai.request.JinaAIEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.jinaai.response.JinaAIEmbeddingsResponseEntity; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -52,11 +50,11 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getTextInputs(); - InputType inputType = input.getInputType(); + var embeddingsInput = inferenceInputs.castTo(EmbeddingsInput.class); + var inferenceStringGroups = embeddingsInput.getInputs(); + var inputType = embeddingsInput.getInputType(); - JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); + var request = new JinaAIEmbeddingsRequest(inferenceStringGroups, inputType, model); 
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 88372c2c70244..a3f3734d22576 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.jinaai; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -30,8 +32,10 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import 
org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -43,7 +47,6 @@ import org.elasticsearch.xpack.inference.services.jinaai.action.JinaAIActionCreator; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -54,6 +57,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.inference.InferenceStringGroup.containsNonTextEntry; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -66,6 +70,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceFields.EMBEDDING_MAX_BATCH_SIZE; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.updateEmbeddingDetails; public class JinaAIService extends SenderService implements RerankingInferenceService { @@ -74,7 +79,7 @@ public class JinaAIService extends SenderService implements RerankingInferenceSe public static final String NAME = "jinaai"; private static final String SERVICE_NAME = "Jina AI"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); + private static 
final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.EMBEDDING); public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -114,7 +119,7 @@ public void parseRequestConfig( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap( removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) ); @@ -168,13 +173,14 @@ private static JinaAIModel createModel( ConfigurationParseContext context ) { return switch (taskType) { - case TEXT_EMBEDDING -> new JinaAIEmbeddingsModel( + case TEXT_EMBEDDING, EMBEDDING -> new JinaAIEmbeddingsModel( inferenceEntityId, serviceSettings, taskSettings, chunkingSettings, secretSettings, - context + context, + taskType ); case RERANK -> new JinaAIRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); @@ -193,7 +199,7 @@ public JinaAIModel parsePersistedConfigWithSecrets( Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -213,7 +219,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + if 
(TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -285,7 +291,8 @@ protected void doChunkedInfer( boolean batchChunksAcrossInputs = true; if (jinaaiModel.getTaskSettings() instanceof JinaAIEmbeddingsTaskSettings jinaAIEmbeddingsTaskSettings) { batchChunksAcrossInputs = jinaAIEmbeddingsTaskSettings.getLateChunking() == null - || jinaAIEmbeddingsTaskSettings.getLateChunking() == false; + || jinaAIEmbeddingsTaskSettings.getLateChunking() == false + || inputs.stream().anyMatch(c -> c.input().containsNonTextEntry()); } List batchedRequests = new EmbeddingRequestChunker<>( @@ -301,25 +308,40 @@ protected void doChunkedInfer( } } + @Override + protected void doEmbeddingInfer( + Model model, + EmbeddingRequest request, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof JinaAIEmbeddingsModel jinaAIModel) { + if (model.getServiceSettings().isMultimodal() == false && containsNonTextEntry(request.inputs())) { + listener.onFailure(new ElasticsearchStatusException("Non-text input provided for text-only model", RestStatus.BAD_REQUEST)); + } else { + var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); + + ExecutableAction action = jinaAIModel.accept(actionCreator, request.taskSettings()); + action.execute(new EmbeddingsInput(request::inputs, request.inputType()), timeout, listener); + } + } else { + listener.onFailure(createInvalidModelException(model)); + } + + } + @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof JinaAIEmbeddingsModel embeddingsModel) { var serviceSettings = embeddingsModel.getServiceSettings(); var similarityFromModel = serviceSettings.similarity(); var similarityToUse = similarityFromModel == null ? 
defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel; - var maxInputTokens = serviceSettings.maxInputTokens(); - - var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings( - serviceSettings.getCommonSettings().modelId(), - serviceSettings.getCommonSettings().rateLimitSettings() - ), - similarityToUse, - embeddingSize, - maxInputTokens, - serviceSettings.getEmbeddingType(), - serviceSettings.dimensionsSetByUser() - ); + + var updatedServiceSettings = updateEmbeddingDetails(serviceSettings, embeddingSize, similarityToUse); + + if (updatedServiceSettings.equals(serviceSettings)) { + return model; + } return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); } else { @@ -379,7 +401,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( DIMENSIONS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription( "The number of dimensions the resulting embeddings should have. For more information refer to " + "https://api.jina.ai/docs#tag/embeddings/operation/create_embedding_v1_embeddings_post." ) @@ -393,7 +415,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( EMBEDDING_TYPE, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription( Strings.format( "The type of embedding to return. One of %s. 
bit and binary are equivalent and are encoded as " + "bytes with signed int8 precision.", @@ -411,7 +433,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( SIMILARITY, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription( Strings.format( "The similarity measure. One of %s. For float embeddings, the default similarity " + "is dot_product. For bit and binary embeddings, the default similarity is l2_norm.", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java similarity index 62% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java index 134e1a0029d80..4cd742831afb8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; @@ -27,19 +28,28 @@ import java.util.Map; import 
java.util.Objects; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; +import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; -public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { - public static final String NAME = "jinaai_embeddings_service_settings"; +public abstract class BaseJinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { - public static JinaAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName("jina_ai_embedding_type_support_added"); + + static final TransportVersion JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED = TransportVersion.fromName( + "jina_ai_embedding_dimensions_support_added" + ); + + static BaseJinaAIEmbeddingsServiceSettings fromMap(Map map, TaskType taskType, ConfigurationParseContext context) { + Objects.requireNonNull(taskType); ValidationException validationException = new ValidationException(); var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); 
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -58,25 +68,46 @@ public static JinaAIEmbeddingsServiceSettings fromMap(Map map, C dimensionsSetByUser = dimensions != null; } + Boolean multimodalModel = null; + // Do not remove the MULTIMODAL_MODEL field from the map for TEXT_EMBEDDING since it's not supported + if (taskType == TaskType.EMBEDDING) { + multimodalModel = removeAsType(map, MULTIMODAL_MODEL, Boolean.class); + if (multimodalModel == null) { + multimodalModel = true; + } + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new JinaAIEmbeddingsServiceSettings( - commonServiceSettings, - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser - ); + if (taskType == TaskType.EMBEDDING) { + return new JinaAIEmbeddingServiceSettings( + commonServiceSettings, + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } else { + return new JinaAITextEmbeddingServiceSettings( + commonServiceSettings, + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser + ); + } } static JinaAIEmbeddingType parseEmbeddingType(Map map, ValidationException validationException) { return Objects.requireNonNullElse( extractOptionalEnum( map, - ServiceFields.EMBEDDING_TYPE, + EMBEDDING_TYPE, ModelConfigurations.SERVICE_SETTINGS, JinaAIEmbeddingType::fromString, EnumSet.allOf(JinaAIEmbeddingType.class), @@ -86,28 +117,33 @@ static JinaAIEmbeddingType parseEmbeddingType(Map map, Validatio ); } - private static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_type_support_added" - ); - - static final TransportVersion JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_dimensions_support_added" - ); + public static BaseJinaAIEmbeddingsServiceSettings 
updateEmbeddingDetails( + BaseJinaAIEmbeddingsServiceSettings existingSettings, + Integer embeddingSize, + SimilarityMeasure similarityToUse + ) { + if (embeddingSize.equals(existingSettings.dimensions()) && similarityToUse.equals(existingSettings.similarity())) { + return existingSettings; + } + return existingSettings.update(similarityToUse, embeddingSize); + } private final JinaAIServiceSettings commonSettings; private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; private final JinaAIEmbeddingType embeddingType; - private final Boolean dimensionsSetByUser; + private final boolean dimensionsSetByUser; + private final Boolean multimodalModel; - public JinaAIEmbeddingsServiceSettings( + public BaseJinaAIEmbeddingsServiceSettings( JinaAIServiceSettings commonSettings, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable JinaAIEmbeddingType embeddingType, - boolean dimensionsSetByUser + boolean dimensionsSetByUser, + @Nullable Boolean multimodalModel ) { this.commonSettings = commonSettings; this.similarity = similarity; @@ -115,25 +151,47 @@ public JinaAIEmbeddingsServiceSettings( this.maxInputTokens = maxInputTokens; this.embeddingType = embeddingType != null ? embeddingType : JinaAIEmbeddingType.FLOAT; this.dimensionsSetByUser = dimensionsSetByUser; + this.multimodalModel = multimodalModel; } - public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { + public BaseJinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { this.commonSettings = new JinaAIServiceSettings(in); this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); - - this.embeddingType = (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) - ? 
Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT) - : JinaAIEmbeddingType.FLOAT; + if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) { + this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT); + } else { + this.embeddingType = JinaAIEmbeddingType.FLOAT; + } if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) { this.dimensionsSetByUser = in.readBoolean(); } else { this.dimensionsSetByUser = false; } + + if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + this.multimodalModel = in.readOptionalBoolean(); + } else { + this.multimodalModel = null; + } } + /** + * Returns whether this {@link BaseJinaAIEmbeddingsServiceSettings} defaults to supporting multimodal inputs or not + * @return {@code true} if these settings default to supporting multimodal inputs + */ + public abstract boolean getDefaultMultimodal(); + + /** + * Returns a new {@link BaseJinaAIEmbeddingsServiceSettings} with updated similarity and dimensions but all other fields unchanged + * @param similarity the new similarity + * @param dimensions the new dimensions + * @return a new {@link BaseJinaAIEmbeddingsServiceSettings} + */ + public abstract BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions); + public JinaAIServiceSettings getCommonSettings() { return commonSettings; } @@ -172,8 +230,8 @@ public DenseVectorFieldMapper.ElementType elementType() { } @Override - public String getWriteableName() { - return NAME; + public boolean isMultimodal() { + return multimodalModel != null ? 
multimodalModel : getDefaultMultimodal(); } @Override @@ -200,10 +258,15 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + if (similarity != null) { builder.field(SIMILARITY, similarity); } + if (multimodalModel != null) { + builder.field(MULTIMODAL_MODEL, multimodalModel); + } + return builder; } @@ -218,7 +281,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); - if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) { out.writeOptionalEnum(JinaAIEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion())); } @@ -226,23 +288,48 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) { out.writeBoolean(dimensionsSetByUser); } + + if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + out.writeOptionalBoolean(multimodalModel); + } } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - JinaAIEmbeddingsServiceSettings that = (JinaAIEmbeddingsServiceSettings) o; + BaseJinaAIEmbeddingsServiceSettings that = (BaseJinaAIEmbeddingsServiceSettings) o; return Objects.equals(commonSettings, that.commonSettings) && Objects.equals(similarity, that.similarity) && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) && Objects.equals(embeddingType, that.embeddingType) - && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser); + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(multimodalModel, that.multimodalModel); } @Override public int hashCode() { - return 
Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser); + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser, multimodalModel); + } + + @Override + public String toString() { + return "BaseJinaAIEmbeddingsServiceSettings{" + + "commonSettings=" + + commonSettings + + ", similarity=" + + similarity + + ", dimensions=" + + dimensions + + ", maxInputTokens=" + + maxInputTokens + + ", embeddingType=" + + embeddingType + + ", dimensionsSetByUser=" + + dimensionsSetByUser + + ", multimodalModel=" + + multimodalModel + + '}'; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java new file mode 100644 index 0000000000000..2d3e97bd09ae3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; + +import java.io.IOException; +import java.util.Map; + +public class JinaAIEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings { + public static final String NAME = "jinaai_multimodal_embedding_service_settings"; + + public static JinaAIEmbeddingServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return (JinaAIEmbeddingServiceSettings) BaseJinaAIEmbeddingsServiceSettings.fromMap(map, TaskType.EMBEDDING, context); + } + + public JinaAIEmbeddingServiceSettings( + JinaAIServiceSettings commonSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + boolean dimensionsSetByUser, + @Nullable Boolean multimodalModel + ) { + super(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser, multimodalModel); + } + + public JinaAIEmbeddingServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public boolean getDefaultMultimodal() { + return true; + } + + @Override + public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions) { + return new JinaAIEmbeddingServiceSettings( + getCommonSettings(), + similarity, + dimensions, + maxInputTokens(), + getEmbeddingType(), + dimensionsSetByUser(), + isMultimodal() + ); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java index 3bccc940b7fe3..5d22d93e77594 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java @@ -16,6 +16,8 @@ import java.util.Locale; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; + /** * Defines the type of embedding that the Jina AI API should return for a request. * @@ -49,10 +51,6 @@ private static final class RequestConstants { ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.keySet() ); - private static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_type_support_added" - ); - private final DenseVectorFieldMapper.ElementType elementType; private final String requestString; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java index 95f7b570f673a..4c6a4751e60ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java @@ -47,29 +47,32 @@ public JinaAIEmbeddingsModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secrets, - ConfigurationParseContext context + ConfigurationParseContext context, + 
TaskType taskType ) { this( inferenceId, - JinaAIEmbeddingsServiceSettings.fromMap(serviceSettings, context), + BaseJinaAIEmbeddingsServiceSettings.fromMap(serviceSettings, taskType, context), JinaAIEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, DefaultSecretSettings.fromMap(secrets), - null + null, + taskType ); } // should only be used for testing JinaAIEmbeddingsModel( String modelId, - JinaAIEmbeddingsServiceSettings serviceSettings, + BaseJinaAIEmbeddingsServiceSettings serviceSettings, JinaAIEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings, - @Nullable String uri + @Nullable String uri, + TaskType taskType ) { super( - new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, JinaAIService.NAME, serviceSettings, taskSettings, chunkingSettings), + new ModelConfigurations(modelId, taskType, JinaAIService.NAME, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings.getCommonSettings(), @@ -81,13 +84,13 @@ private JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, JinaAIEmbeddingsTaskS super(model, taskSettings); } - public JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, JinaAIEmbeddingsServiceSettings serviceSettings) { + public JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, BaseJinaAIEmbeddingsServiceSettings serviceSettings) { super(model, serviceSettings); } @Override - public JinaAIEmbeddingsServiceSettings getServiceSettings() { - return (JinaAIEmbeddingsServiceSettings) super.getServiceSettings(); + public BaseJinaAIEmbeddingsServiceSettings getServiceSettings() { + return (BaseJinaAIEmbeddingsServiceSettings) super.getServiceSettings(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java new file mode 100644 index 0000000000000..9ec949be0bc29 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; + +import java.io.IOException; +import java.util.Map; + +public class JinaAITextEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings { + /** + * This name is a holdover from before the introduction of {@link JinaAIEmbeddingServiceSettings} to support multimodal embeddings + */ + public static final String NAME = "jinaai_embeddings_service_settings"; + + public static JinaAITextEmbeddingServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return (JinaAITextEmbeddingServiceSettings) BaseJinaAIEmbeddingsServiceSettings.fromMap(map, TaskType.TEXT_EMBEDDING, context); + } + + public JinaAITextEmbeddingServiceSettings( + JinaAIServiceSettings commonServiceSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dims, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingTypes, + boolean dimensionsSetByUser + ) { + super(commonServiceSettings, similarity, 
dims, maxInputTokens, embeddingTypes, dimensionsSetByUser, null); + } + + public JinaAITextEmbeddingServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public boolean getDefaultMultimodal() { + return false; + } + + @Override + public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions) { + return new JinaAITextEmbeddingServiceSettings( + getCommonSettings(), + similarity, + dimensions, + maxInputTokens(), + getEmbeddingType(), + dimensionsSetByUser() + ); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java index 7027533746884..e19d06e8fe08c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java @@ -10,7 +10,9 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; @@ -23,11 +25,11 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest { - private final List input; + private final List input; private final InputType inputType; private final JinaAIEmbeddingsModel model; - public JinaAIEmbeddingsRequest(List input, InputType inputType, 
JinaAIEmbeddingsModel embeddingsModel) { + public JinaAIEmbeddingsRequest(List input, InputType inputType, JinaAIEmbeddingsModel embeddingsModel) { this.input = Objects.requireNonNull(input); this.inputType = inputType; this.model = Objects.requireNonNull(embeddingsModel); @@ -70,4 +72,8 @@ public boolean[] getTruncationInfo() { public JinaAIEmbeddingType getEmbeddingType() { return model.getServiceSettings().getEmbeddingType(); } + + public TaskType getTaskType() { + return model.getTaskType(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java index c68731882ffc6..fdb0997bf129b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.jinaai.request; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -18,9 +19,10 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.inference.InferenceStringGroup.toStringList; import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; -public record JinaAIEmbeddingsRequestEntity(List input, InputType inputType, JinaAIEmbeddingsModel model) +public record JinaAIEmbeddingsRequestEntity(List input, InputType inputType, JinaAIEmbeddingsModel model) implements ToXContentObject { @@ -29,6 +31,8 @@ public record JinaAIEmbeddingsRequestEntity(List input, InputType inputT private static final String CLUSTERING = 
"separation"; private static final String CLASSIFICATION = "classification"; private static final String INPUT_FIELD = "input"; + private static final String INPUT_TEXT_FIELD = "text"; + private static final String INPUT_IMAGE_FIELD = "image"; private static final String MODEL_FIELD = "model"; public static final String TASK_TYPE_FIELD = "task"; public static final String LATE_CHUNKING = "late_chunking"; @@ -46,7 +50,7 @@ public record JinaAIEmbeddingsRequestEntity(List input, InputType inputT @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(INPUT_FIELD, input); + writeInputs(builder); builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); builder.field(EMBEDDING_TYPE_FIELD, model.getServiceSettings().getEmbeddingType().toRequestString()); @@ -60,7 +64,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } if (taskSettings.getLateChunking() != null) { - builder.field(LATE_CHUNKING, taskSettings.getLateChunking() && getInputWordCount() <= MAX_WORD_COUNT_FOR_LATE_CHUNKING); + builder.field( + LATE_CHUNKING, + // Late chunking is not supported for image inputs + taskSettings.getLateChunking() + && InferenceStringGroup.containsNonTextEntry(input) == false + && getInputWordCount() <= MAX_WORD_COUNT_FOR_LATE_CHUNKING + ); } if (model.getServiceSettings().dimensionsSetByUser() && model.getServiceSettings().dimensions() != null) { @@ -71,6 +81,26 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + private void writeInputs(XContentBuilder builder) throws IOException { + if (model.getServiceSettings().isMultimodal()) { + builder.startArray(INPUT_FIELD); + for (var inferenceStringGroup : input) { + for (var inferenceString : inferenceStringGroup.inferenceStrings()) { + builder.startObject(); + if (inferenceString.isText()) { + builder.field(INPUT_TEXT_FIELD, inferenceString.value()); + } 
else if (inferenceString.isImage()) { + builder.field(INPUT_IMAGE_FIELD, inferenceString.value()); + } + builder.endObject(); + } + } + builder.endArray(); + } else { + builder.field(INPUT_FIELD, toStringList(input)); + } + } + // default for testing static String convertInputType(InputType inputType) { return switch (inputType) { @@ -87,8 +117,8 @@ static String convertInputType(InputType inputType) { private int getInputWordCount() { int wordCount = 0; - for (var text : input) { - wordCount += ChunkerUtils.countWords(text); + for (var inferenceStringGroup : input) { + wordCount += ChunkerUtils.countWords(inferenceStringGroup.textValue()); } return wordCount; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java index 4471012b6e57c..ed35f50ae820b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java @@ -7,10 +7,11 @@ package org.elasticsearch.xpack.inference.services.jinaai.response; +import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -18,6 +19,10 @@ import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; import 
org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -39,14 +44,15 @@ public class JinaAIEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response"; - private static final Map> EMBEDDING_PARSERS = Map.of( - toLowerCase(JinaAIEmbeddingType.FLOAT), - JinaAIEmbeddingsResponseEntity::parseFloatDataObject, - toLowerCase(JinaAIEmbeddingType.BIT), - JinaAIEmbeddingsResponseEntity::parseBitDataObject, - toLowerCase(JinaAIEmbeddingType.BINARY), - JinaAIEmbeddingsResponseEntity::parseBitDataObject - ); + private static final Map> EMBEDDING_PARSERS = + Map.of( + toLowerCase(JinaAIEmbeddingType.FLOAT), + JinaAIEmbeddingsResponseEntity::parseFloatDataObject, + toLowerCase(JinaAIEmbeddingType.BIT), + JinaAIEmbeddingsResponseEntity::parseBitDataObject, + toLowerCase(JinaAIEmbeddingType.BINARY), + JinaAIEmbeddingsResponseEntity::parseBitDataObject + ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); private static String supportedEmbeddingTypes() { @@ -112,6 +118,7 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r ); } + var taskType = embeddingsRequest.getTaskType(); var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try 
(XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { moveToFirstToken(jsonParser); @@ -121,20 +128,26 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); - return embeddingValueParser.apply(jsonParser); + return embeddingValueParser.apply(jsonParser, taskType); } } - private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException { - List embeddingList = parseList( + private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser, TaskType taskType) throws IOException { + List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject ); - return new DenseEmbeddingFloatResults(embeddingList); + if (taskType == TaskType.TEXT_EMBEDDING) { + return new DenseEmbeddingFloatResults(embeddingList); + } else if (taskType == TaskType.EMBEDDING) { + return new GenericDenseEmbeddingFloatResults(embeddingList); + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } } - private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { + private static EmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -143,19 +156,25 @@ private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XC // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return EmbeddingFloatResults.Embedding.of(embeddingValuesList); } - private static InferenceServiceResults parseBitDataObject(XContentParser 
jsonParser) throws IOException { + private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser, TaskType taskType) throws IOException { List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject ); - return new DenseEmbeddingBitResults(embeddingList); + if (taskType == TaskType.TEXT_EMBEDDING) { + return new DenseEmbeddingBitResults(embeddingList); + } else if (taskType == TaskType.EMBEDDING) { + return new GenericDenseEmbeddingBitResults(embeddingList); + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } } - private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { + private static EmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -164,7 +183,7 @@ private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XCont // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return DenseEmbeddingByteResults.Embedding.of(embeddingList); + return EmbeddingByteResults.Embedding.of(embeddingList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java index 9140795d3d796..f4d2610688922 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java @@ -23,6 +23,7 @@ import org.elasticsearch.rest.RestStatus; import java.util.List; +import java.util.Map; public class SimpleEmbeddingServiceIntegrationValidator implements ServiceIntegrationValidator { // The below data URI represents the base64 encoding of 28x28 pixel black square .jpg image @@ -44,7 +45,13 @@ public class SimpleEmbeddingServiceIntegrationValidator implements ServiceIntegr @Override public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { - EmbeddingRequest request = new EmbeddingRequest(List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT), InputType.INTERNAL_INGEST); + List inputList; + if (model.getServiceSettings().isMultimodal()) { + inputList = List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT); + } else { + inputList = List.of(TEST_TEXT_INPUT); + } + EmbeddingRequest request = new EmbeddingRequest(inputList, InputType.INTERNAL_INGEST, Map.of()); service.embeddingInfer(model, request, timeout, ActionListener.wrap(r -> { if (r != null) { listener.onResponse(r); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java index 9f2055d825fd7..4da01e0aca8af 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -44,6 +44,14 @@ public static InputType randomWithNull() { ); } + public static InputType randomRequestType() { + return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, InputType.CLASSIFICATION); + } + + public static InputType randomRequestTypeWithNull() { + return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, 
InputType.CLASSIFICATION, null); + } + public static InputType randomSearchAndIngestWithNull() { return randomBoolean() ? null diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java index f6c058bdbb79f..4d8244c22c83c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java @@ -24,4 +24,7 @@ public void testFromStringOrStatusException() { assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY)); } + public static TaskType randomEmbeddingTaskType() { + return randomFrom(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index cfd9216e17a07..c232f92bc50f8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.Model; @@ -155,12 +156,19 @@ public record PersistedConfig(Map config, Map se public static PersistedConfig getPersistedConfigMap( Map serviceSettings, Map taskSettings, - Map chunkingSettings, - Map secretSettings + @Nullable Map chunkingSettings, + @Nullable Map secretSettings ) { + var secrets = secretSettings == null ? 
null : new HashMap(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)); + + var persistedConfigMap = new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + secrets + ); - var persistedConfigMap = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - persistedConfigMap.config.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + if (chunkingSettings != null) { + persistedConfigMap.config.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + } return persistedConfigMap; } @@ -168,27 +176,19 @@ public static PersistedConfig getPersistedConfigMap( public static PersistedConfig getPersistedConfigMap( Map serviceSettings, Map taskSettings, - Map secretSettings + @Nullable Map secretSettings ) { - var secrets = secretSettings == null ? null : new HashMap(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)); + return getPersistedConfigMap(serviceSettings, taskSettings, null, secretSettings); + } - return new PersistedConfig( - new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), - secrets - ); + public static PersistedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + return Utils.getPersistedConfigMap(serviceSettings, taskSettings, null); } public static PersistedConfig getPersistedConfigMap(Map serviceSettings) { return Utils.getPersistedConfigMap(serviceSettings, new HashMap<>(), null); } - public static PersistedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { - return new PersistedConfig( - new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), - null - ); - } - public static Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java index 3e89c0c0425d9..fdec51a1a57d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; @@ -23,6 +24,7 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; import static org.hamcrest.Matchers.is; public class JinaAIServiceSettingsTests extends AbstractBWCWireSerializationTestCase { @@ -104,11 +106,15 @@ protected JinaAIServiceSettings mutateInstance(JinaAIServiceSettings instance) t return new JinaAIServiceSettings(modelId, rateLimitSettings); } - public static Map getServiceSettingsMap(String model) { + public static Map getServiceSettingsMap(String model, @Nullable Integer requestsPerMinute) { var map = new HashMap(); map.put(ServiceFields.MODEL_ID, model); + if (requestsPerMinute != null) { + map.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))); + } + return map; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 
4670d32b8db61..424553d0a39e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -18,27 +18,37 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceString; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsOptions; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import 
org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -47,10 +57,14 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; -import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModelTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -61,23 +75,36 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; import static 
org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.InferenceString.DataFormat.BASE64; +import static org.elasticsearch.inference.InferenceString.DataType.IMAGE; +import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS; +import static org.elasticsearch.inference.TaskType.EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.DEFAULT_SETTINGS; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; -import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static 
org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS; +import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfMinimalEmbeddingSettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -92,8 +119,8 @@ @SuppressWarnings("resource") public class JinaAIServiceTests extends InferenceServiceTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); public static final String DEFAULT_EMBEDDING_URL = "https://api.jina.ai/v1/embeddings"; + public static final String DEFAULT_RERANK_URL = "https://api.jina.ai/v1/rerank"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -112,188 +139,273 @@ public void shutdown() throws IOException { webServer.close(); } - public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { + public void testParseRequestConfig_createsEmbeddingsModel_textEmbeddingTask() throws IOException { + testParseRequestConfig_createsEmbeddingModel(TEXT_EMBEDDING); + } + + public void testParseRequestConfig_createsEmbeddingsModel_embeddingTask() throws IOException { + testParseRequestConfig_createsEmbeddingModel(EMBEDDING); + } + + private void 
testParseRequestConfig_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var maxInputTokens = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var multimodalModel = taskType == EMBEDDING && randomBoolean(); + var inputType = InputTypeTests.randomRequestType(); + var lateChunking = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + var chunkingSettings = createRandomChunkingSettings(); + + var serviceSettingsMap = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute + ); + + if (taskType == EMBEDDING) { + serviceSettingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT)); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, + taskType, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - 
getSecretSettingsMap("secret") + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(inputType, lateChunking), + chunkingSettings.asMap(), + getSecretSettingsMap(apiKey) ), modelListener ); + assertEmbeddingModelSettings( + modelListener.actionGet(), + modelName, + new RateLimitSettings(requestsPerMinute), + similarity, + dimensions, + true, + maxInputTokens, + embeddingType, + multimodalModel, + new JinaAIEmbeddingsTaskSettings(inputType, lateChunking), + chunkingSettings, + apiKey + ); } } - public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParseRequestConfig_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT)); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + + var modelListener = new PlainActionFuture(); service.parseRequestConfig( 
"id", - TaskType.TEXT_EMBEDDING, + TaskType.RERANK, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") + JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute), + JinaAIRerankTaskSettingsTests.getTaskSettingsMap(topN, returnDocuments), + getSecretSettingsMap(apiKey) ), modelListener ); + assertRerankModelSettings( + modelListener.actionGet(), + modelName, + new RateLimitSettings(requestsPerMinute), + apiKey, + new JinaAIRerankTaskSettings(topN, returnDocuments) + ); } } - public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel_textEmbedding() throws IOException { + testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel(TEXT_EMBEDDING); + } + + public void testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel_embedding() throws IOException { + testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel(EMBEDDING); + } + + private void testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.BIT)); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - 
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, + taskType, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.BIT), - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - getSecretSettingsMap("secret") + getMapOfCommonEmbeddingSettings(modelName, null, null, null, null, null, null), + getSecretSettingsMap(apiKey) ), modelListener ); + assertEmbeddingModelSettings( + modelListener.actionGet(), + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + null, + null, + false, + null, + JinaAIEmbeddingType.FLOAT, + taskType == EMBEDDING, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + DEFAULT_SETTINGS, + apiKey + ); } } - public void testParseRequestConfig_OptionalTaskSettings() throws IOException { + public void testParseRequestConfig_onlyRequiredSettings_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), 
equalTo(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - getSecretSettingsMap("secret") - ), + TaskType.RERANK, + getRequestConfigMap(JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName), getSecretSettingsMap(apiKey)), modelListener ); + assertRerankModelSettings( + modelListener.actionGet(), + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + apiKey, + JinaAIRerankTaskSettings.EMPTY_SETTINGS + ); + } } - public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOException { + public void testParseRequestConfig_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { - var failureListener = getModelListenerForStatusException("The [jinaai] service does not support task type [sparse_embedding]"); + var unsupportedTaskType = randomValueOtherThanMany( + t -> service.supportedTaskTypes().contains(t), + () -> randomFrom(TaskType.values()) + ); + var failureListener = getModelListenerForStatusException( + Strings.format("The [jinaai] service does not support task type [%s]", unsupportedTaskType) + ); service.parseRequestConfig( "id", - TaskType.SPARSE_EMBEDDING, - getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") - ), + unsupportedTaskType, + getRequestConfigMap(getMapOfMinimalEmbeddingSettings("model"), getSecretSettingsMap("secret")), failureListener ); } } - private static ActionListener getModelListenerForStatusException(String expectedMessage) { - return ActionListener.wrap((model) -> fail("Model parsing should 
have failed"), e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is(expectedMessage)); - }); - } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { - var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") - ); + var config = getRequestConfigMap(getMapOfMinimalEmbeddingSettings("model"), getSecretSettingsMap("secret")); config.put("extra_key", "value"); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { try (var service = createJinaAIService()) { - var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT); + var serviceSettings = getMapOfMinimalEmbeddingSettings("model"); serviceSettings.put("extra_key", "value"); - var config = getRequestConfigMap(serviceSettings, Map.of(), getSecretSettingsMap("secret")); + var config = getRequestConfigMap(serviceSettings, getSecretSettingsMap("secret")); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); + } + } + + public void testParseRequestConfig_textEmbedding_throwsWhenMultimodalModelKeyExistsInServiceSettingsMap() throws IOException { + try 
(var service = createJinaAIService()) { + var serviceSettings = getMapOfMinimalEmbeddingSettings("model"); + serviceSettings.put(MULTIMODAL_MODEL, true); + + var config = getRequestConfigMap(serviceSettings, getSecretSettingsMap("secret")); + + var failureListener = getModelListenerForStatusException( + "Configuration contains settings [{multimodal_model=true}] unknown to the [jinaai] service" + ); + service.parseRequestConfig("id", TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_embedding_doesNotThrowWhenMultimodalModelKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createJinaAIService()) { + String modelName = "model"; + var serviceSettings = getMapOfMinimalEmbeddingSettings(modelName); + var multimodalModel = randomBoolean(); + serviceSettings.put(MULTIMODAL_MODEL, multimodalModel); + + String apiKey = "secret"; + var config = getRequestConfigMap(serviceSettings, getSecretSettingsMap(apiKey)); + + var modelListener = new PlainActionFuture(); + service.parseRequestConfig("id", EMBEDDING, config, modelListener); + + assertEmbeddingModelSettings( + modelListener.actionGet(), + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + null, + null, + false, + null, + JinaAIEmbeddingType.FLOAT, + multimodalModel, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + DEFAULT_SETTINGS, + apiKey + ); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { try (var service = createJinaAIService()) { - var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); - taskSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - taskSettingsMap, + getMapOfMinimalEmbeddingSettings("model"), + new HashMap<>(Map.of("extra_key", "value")), getSecretSettingsMap("secret") ); var failureListener = getModelListenerForStatusException( 
"Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); - + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); } } @@ -302,439 +414,440 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); - var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - secretSettingsMap - ); + var config = getRequestConfigMap(getMapOfMinimalEmbeddingSettings("model"), secretSettingsMap); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); } } - public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModel() throws IOException { + public void testParsePersistedConfigWithSecrets_createsEmbeddingsModel_textEmbedding() throws IOException { + testParsePersistedConfigWithSecrets_createsEmbeddingModel(TEXT_EMBEDDING); + } + + public void testParsePersistedConfigWithSecrets_createsEmbeddingsModel_embedding() throws IOException { + testParsePersistedConfigWithSecrets_createsEmbeddingModel(EMBEDDING); + } + + private void testParsePersistedConfigWithSecrets_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute 
= randomNonNegativeInt(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var dimensionsSetByUser = randomBoolean(); + var maxInputTokens = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var multimodalModel = taskType == EMBEDDING && randomBoolean(); + var inputType = InputTypeTests.randomRequestType(); + var lateChunking = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + var chunkingSettings = createRandomChunkingSettings(); + + var serviceSettingsMap = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + if (taskType == EMBEDDING) { + serviceSettingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(inputType, lateChunking), + chunkingSettings.asMap(), + getSecretSettingsMap(apiKey) ); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var model = service.parsePersistedConfigWithSecrets("id", taskType, persistedConfig.config(), persistedConfig.secrets()); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertEmbeddingModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + multimodalModel, + new 
JinaAIEmbeddingsTaskSettings(inputType, lateChunking), + chunkingSettings, + apiKey + ); } } - public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") + JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute), + JinaAIRerankTaskSettingsTests.getTaskSettingsMap(topN, returnDocuments), + getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, persistedConfig.config(), persistedConfig.secrets()); + + assertRerankModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + apiKey, + new JinaAIRerankTaskSettings(topN, returnDocuments) ); + } + } - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingsModel_textEmbedding() throws IOException { + testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TEXT_EMBEDDING); + } - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - 
assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingsModel_embedding() throws IOException { + testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(EMBEDDING); } - public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + private void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + Map chunkingSettingsMap = randomBoolean() ? Map.of() : null; + var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), + getMapOfMinimalEmbeddingSettings(modelName), Map.of(), - getSecretSettingsMap("secret") + chunkingSettingsMap, + getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfigWithSecrets("id", taskType, persistedConfig.config(), persistedConfig.secrets()); + + assertEmbeddingModelSettings( + model, + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + null, + null, + false, + null, + JinaAIEmbeddingType.FLOAT, + taskType == EMBEDDING, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + chunkingSettingsMap == null ? 
OLD_DEFAULT_SETTINGS : DEFAULT_SETTINGS, + apiKey ); + } + } + + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { + try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), getSecretSettingsMap(apiKey)); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, persistedConfig.config(), persistedConfig.secrets()); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertRerankModelSettings(model, modelName, DEFAULT_RATE_LIMIT_SETTINGS, apiKey, JinaAIRerankTaskSettings.EMPTY_SETTINGS); } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfigWithSecrets_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("oldmodel", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") + var unsupportedTaskType = randomValueOtherThanMany( + t -> service.supportedTaskTypes().contains(t), + () -> randomFrom(TaskType.values()) ); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("oldmodel", null), Map.of(), 
getSecretSettingsMap("secret")); var thrownException = expectThrows( ElasticsearchStatusException.class, () -> service.parsePersistedConfigWithSecrets( "id", - TaskType.SPARSE_EMBEDDING, + unsupportedTaskType, persistedConfig.config(), persistedConfig.secrets() ) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); - assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); + assertThat( + thrownException.getMessage(), + containsString(Strings.format("The [jinaai] service does not support task type [%s]", unsupportedTaskType)) + ); } } public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { + String modelName = "modelName"; + String apiKey = "secret"; var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), + getServiceSettingsMap(modelName, null), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), - getSecretSettingsMap("secret") + getSecretSettingsMap(apiKey) ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, 
apiKey); } } public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createJinaAIService()) { - var secretSettingsMap = getSecretSettingsMap("secret"); + String modelName = "modelName"; + String apiKey = "secret"; + var secretSettingsMap = getSecretSettingsMap(apiKey); secretSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), secretSettingsMap); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") - ); - persistedConfig.secrets().put("extra_key", "value"); - - var model = 
service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); + String modelName = "modelName"; + String apiKey = "secret"; + var serviceSettingsMap = getServiceSettingsMap(modelName, null); + serviceSettingsMap.put("extra_key", "value"); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, Map.of(), getSecretSettingsMap("secret")); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createJinaAIService()) { - var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT); - serviceSettingsMap.put("extra_key", "value"); + String modelName = "modelName"; + String apiKey = "secret"; - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, Map.of(), getSecretSettingsMap("secret")); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(modelName, null), + new HashMap<>(Map.of("extra_key", "value")), + getSecretSettingsMap(apiKey) ); - assertThat(model, 
instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInChunkingSettings() throws IOException { try (var service = createJinaAIService()) { - var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); - taskSettingsMap.put("extra_key", "value"); + String modelName = "modelName"; + String apiKey = "secret"; var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - taskSettingsMap, - getSecretSettingsMap("secret") + getServiceSettingsMap(modelName, null), + Map.of(), + Map.of(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.NONE.toString(), "extra_key", "value"), + getSecretSettingsMap(apiKey) ); var model = service.parsePersistedConfigWithSecrets( "id", - TaskType.TEXT_EMBEDDING, + randomEmbeddingTaskType(), persistedConfig.config(), persistedConfig.secrets() ); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); - 
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is(apiKey)); } } - public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { - try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of() - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + public void testParsePersistedConfig_createsEmbeddingsModel_textEmbedding() throws IOException { + testParsePersistedConfig_createsEmbeddingModel(TEXT_EMBEDDING); + } - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertNull(embeddingsModel.getSecretSettings()); - } + public void testParsePersistedConfig_createsEmbeddingsModel_embedding() throws IOException { + testParsePersistedConfig_createsEmbeddingModel(EMBEDDING); } - public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + private void testParsePersistedConfig_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - createRandomChunkingSettingsMap() + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var similarity = randomSimilarityMeasure(); + var 
dimensions = randomNonNegativeInt(); + var dimensionsSetByUser = randomBoolean(); + var maxInputTokens = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var multimodalModel = taskType == EMBEDDING && randomBoolean(); + var inputType = InputTypeTests.randomRequestType(); + var lateChunking = randomBoolean(); + var chunkingSettings = createRandomChunkingSettings(); + + var serviceSettingsMap = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + if (taskType == EMBEDDING) { + serviceSettingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(inputType, lateChunking), + chunkingSettings.asMap(), + null + ); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); + var model = service.parsePersistedConfig("id", taskType, persistedConfig.config()); + + assertEmbeddingModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + multimodalModel, + new JinaAIEmbeddingsTaskSettings(inputType, lateChunking), + chunkingSettings, + "" + ); } } - public void 
testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of() + JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute), + JinaAIRerankTaskSettingsTests.getTaskSettingsMap(topN, returnDocuments), + null ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig("id", TaskType.RERANK, persistedConfig.config()); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); + assertRerankModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + "", + new JinaAIRerankTaskSettings(topN, returnDocuments) + ); } } - public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - 
JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model_old", JinaAIEmbeddingType.FLOAT), - Map.of() + var unsupportedTaskType = randomValueOtherThanMany( + t -> service.supportedTaskTypes().contains(t), + () -> randomFrom(TaskType.values()) ); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model_old", null)); var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + () -> service.parsePersistedConfig("id", unsupportedTaskType, persistedConfig.config()) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); - assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); + assertThat( + thrownException.getMessage(), + containsString(Strings.format("The [jinaai] service does not support task type [%s]", unsupportedTaskType)) + ); } } public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of() - ); + String modelName = "modelName"; + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null)); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - 
assertNull(embeddingsModel.getSecretSettings()); + assertParsePersistedConfigMinimalSettings(service, persistedConfig, modelName); } } public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createJinaAIService()) { - var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT); + String modelName = "modelName"; + var serviceSettingsMap = getServiceSettingsMap(modelName, null); serviceSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap( - serviceSettingsMap, - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH) - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); - assertNull(embeddingsModel.getSecretSettings()); + assertParsePersistedConfigMinimalSettings(service, persistedConfig, modelName); } } public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createJinaAIService()) { - var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); - taskSettingsMap.put("extra_key", "value"); + String modelName = "modelName"; + var taskSettingsMap = new HashMap(Map.of("extra_key", "value")); - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - taskSettingsMap - ); - - var model = 
service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), taskSettingsMap); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertNull(embeddingsModel.getSecretSettings()); + assertParsePersistedConfigMinimalSettings(service, persistedConfig, modelName); } } - public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException { - var sender = createMockSender(); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var mockModel = getInvalidModel("model_id", "service_name"); + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInChunkingSettings() throws IOException { + try (var service = createJinaAIService()) { + String modelName = "modelName"; - try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(modelName, null), + Map.of(), + Map.of(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.NONE.toString(), "extra_key", "value"), + null ); - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - thrownException.getMessage(), - is("The internal model was invalid, please delete the service 
[service_name] with id [model_id] and add it again.") - ); + var model = service.parsePersistedConfig("id", randomEmbeddingTaskType(), persistedConfig.config()); - verify(factory, times(1)).createSender(); - verify(sender, times(1)).startAsynchronously(any()); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is("")); } - - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); } public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { - testUpdateModelWithEmbeddingDetails_Successful(null); + testUpdateModelWithEmbeddingDetails_Successful(null, 128); } public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { - testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values())); + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()), 128); + } + + public void testUpdateModelWithEmbeddingDetails_NullDimensionsInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()), null); } - private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { + public void testUpdateModelWithEmbeddingDetails_NullSimilarityAndDimensionsInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(null, null); + } + + private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure, Integer dimensions) + throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { @@ -745,13 +858,15 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si randomAlphaOfLength(10), 
RateLimitSettingsTests.createRandom(), similarityMeasure, - randomNonNegativeInt(), + dimensions, randomNonNegativeInt(), embeddingType, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, randomAlphaOfLength(10), - false + false, + randomEmbeddingTaskType(), + randomBoolean() ); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); @@ -764,120 +879,48 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si } } - public void testInfer_Embedding_UnauthorisedResponse() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - { - "detail": "Unauthorized" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel(getUrl(webServer), "model", "secret"); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - null, - null, - List.of("abc"), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); - assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); - assertThat(webServer.requests(), hasSize(1)); - } - } - - public void testInfer_Rerank_UnauthorisedResponse() throws IOException { + public void testUpdateModelWithEmbeddingDetails_returnsExistingModelIfSettingsUnchanged() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - 
String responseJson = """ - { - "detail": "Unauthorized" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - - var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - "query", - null, - null, - List.of("candidate1", "candidate2"), - false, - new HashMap<>(), - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); - assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); - assertThat(webServer.requests(), hasSize(1)); - } - } - - public void testInfer_TextEmbedding_Get_Response_Ingest() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var similarityMeasure = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var model = JinaAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + 
randomAlphaOfLength(10), + RateLimitSettingsTests.createRandom(), + similarityMeasure, + dimensions, + randomNonNegativeInt(), + randomFrom(JinaAIEmbeddingType.values()), JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - apiKey, - dimensions + null, + randomAlphaOfLength(10), + false, + randomEmbeddingTaskType(), + randomBoolean() ); + + assertThat(service.updateModelWithEmbeddingDetails(model, dimensions), sameInstance(model)); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException { + var sender = createMockSender(); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); service.infer( - model, + mockModel, null, null, null, - input, + List.of(""), false, new HashMap<>(), InputType.INGEST, @@ -885,179 +928,119 @@ public void testInfer_TextEmbedding_Get_Response_Ingest() throws IOException { listener ); - var result = listener.actionGet(TIMEOUT); - - assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); - - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "task", - "retrieval.passage", - 
"embedding_type", - "float", - "dimensions", - dimensions - ) - ) + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).startAsynchronously(any()); } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); } - public void testInfer_TextEmbedding_Get_Response_Search() throws IOException { + public void testInfer_TextEmbedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] + "detail": "Unauthorized" } """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - apiKey, - dimensions - ); + var model = JinaAIEmbeddingsModelTests.createTextEmbeddingModel(getUrl(webServer), "model", "secret"); PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); service.infer( model, null, null, null, - input, + List.of("abc"), false, new HashMap<>(), - InputType.SEARCH, + InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener ); - var result = listener.actionGet(TIMEOUT); - - assertEquals(buildExpectationFloat(List.of(new 
float[] { 0.123F, -0.123F })), result.asMap()); - + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "task", - "retrieval.query", - "embedding_type", - "float", - "dimensions", - dimensions - ) - ) - ); } } - public void testInfer_TextEmbedding_Get_Response_clustering() throws IOException { + public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ - {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, - "data":[{"object":"embedding","index":0,"embedding":[0.123, -0.123]}]} + { + "detail": "Unauthorized" + } """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - apiKey, 
- dimensions - ); + var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false); PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); service.infer( model, + "query", null, null, - null, - input, + List.of("candidate1", "candidate2"), false, new HashMap<>(), - InputType.CLUSTERING, + null, InferenceAction.Request.DEFAULT_TIMEOUT, listener ); - var result = listener.actionGet(TIMEOUT); + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + assertThat(webServer.requests(), hasSize(1)); + } + } - assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + public void testInfer_TextEmbedding_Get_Response_Ingest() throws IOException { + testInfer_TextEmbedding_Get_Response(randomFrom(InputType.INGEST, InputType.INTERNAL_INGEST), "retrieval.passage"); + } - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); + public void testInfer_TextEmbedding_Get_Response_Search() throws IOException { + testInfer_TextEmbedding_Get_Response(randomFrom(InputType.SEARCH, InputType.INTERNAL_SEARCH), "retrieval.query"); + } - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat( - requestMap, - is(Map.of("input", input, "model", modelName, "task", "separation", "embedding_type", "float", "dimensions", dimensions)) - ); - } + public void testInfer_TextEmbedding_Get_Response_clustering() throws IOException { + 
testInfer_TextEmbedding_Get_Response(InputType.CLUSTERING, "separation"); + } + + public void testInfer_TextEmbedding_Get_Response_classification() throws IOException { + testInfer_TextEmbedding_Get_Response(InputType.CLASSIFICATION, "classification"); + } + + public void testInfer_TextEmbedding_Get_Response_unspecified() throws IOException { + testInfer_TextEmbedding_Get_Response(InputType.UNSPECIFIED, null); } public void testInfer_TextEmbedding_Get_Response_NullInputType() throws IOException { + testInfer_TextEmbedding_Get_Response(null, null); + } + + private void testInfer_TextEmbedding_Get_Response(InputType inputType, String expectedJinaTask) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - String responseJson = """ + var responseJson = """ { "model": "jina-clip-v2", "object": "list", @@ -1077,24 +1060,38 @@ public void testInfer_TextEmbedding_Get_Response_NullInputType() throws IOExcept ] } """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - String apiKey = "apiKey"; - int dimensions = 1024; String modelName = "jina-clip-v2"; + int dimensions = 1024; + String apiKey = "apiKey"; var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), modelName, JinaAIEmbeddingType.FLOAT, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, - dimensions + dimensions, + TEXT_EMBEDDING, + null ); PlainActionFuture listener = new PlainActionFuture<>(); List input = List.of("abc"); - service.infer(model, null, null, null, input, false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.infer( + model, + null, + null, + null, + input, + false, + new HashMap<>(), + inputType, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); - var result = listener.actionGet(TIMEOUT); + var result = 
listener.actionGet(TEST_REQUEST_TIMEOUT); assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); @@ -1103,8 +1100,14 @@ public void testInfer_TextEmbedding_Get_Response_NullInputType() throws IOExcept assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); + Map expectedRequestMap = new HashMap<>( + Map.of("input", input, "model", modelName, "embedding_type", "float", "dimensions", dimensions) + ); + if (expectedJinaTask != null) { + expectedRequestMap.put("task", expectedJinaTask); + } var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "dimensions", dimensions))); + assertThat(requestMap, is(expectedRequestMap)); } } @@ -1150,7 +1153,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1232,7 +1235,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1326,7 +1329,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1406,7 +1409,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept listener ); - var result = listener.actionGet(TIMEOUT); + var result = 
listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1426,188 +1429,152 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat( - requestMap, - is( - Map.of( - "query", - "query", - "documents", - List.of("candidate1", "candidate2", "candidate3", "candidate4"), - "model", - "model", - "return_documents", - true, - "top_n", - 3 - ) - ) - ); - - } - - } - - public void testInfer_TextEmbedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() - throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - new JinaAIEmbeddingsTaskSettings(null, null), - apiKey, - dimensions - ); - PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); - service.infer( - model, - null, - null, - null, - input, - false, - new HashMap<>(), - InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new 
float[] { 0.123F, -0.123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "dimensions", dimensions))); + assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + true, + "top_n", + 3 + ) + ) + ); + } } - public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + public void test_TextEmbeddingModel_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), createRandomChunkingSettings(), - "secret" + "secret", + TEXT_EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { + public void test_TextEmbeddingModel_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), - "secret" + "secret", + TEXT_EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_LateChunkingEnabled() throws IOException { 
+ public void test_TextEmbeddingModel_ChunkedInfer_LateChunkingEnabled() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), - "secret" + "secret", + TEXT_EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_LateChunkingDisabled() throws IOException { + public void test_TextEmbeddingModel_ChunkedInfer_LateChunkingDisabled() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), - "secret" + "secret", + TEXT_EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_noInputs() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var model = JinaAIEmbeddingsModelTests.createModel(getUrl(webServer), "jina-clip-v2", "secret"); + public void test_embeddingModel_chunkedInfer_batchesCallsWhenLateChunkingEnabled() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), + "secret", + EMBEDDING + ); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer( - model, - null, - List.of(), - new HashMap<>(), - InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); + } - var results = 
listener.actionGet(TIMEOUT); - assertThat(results, empty()); - assertThat(webServer.requests(), empty()); - } + public void test_embeddingModel_chunkedInfer_batchesCallsWhenLateChunkingEnabled_inputContainsNonTextInput() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), + "secret", + EMBEDDING + ); + + test_embedding_chunkedInfer_batchesCalls(model, false, true); + } + + public void test_embeddingModel_chunkedInfer_batchesCallsWhenLateChunkingDisabled_inputContainsNonTextInput() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), + "secret", + EMBEDDING + ); + + test_embedding_chunkedInfer_batchesCalls(model, false, true); } - private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { + private void test_embedding_chunkedInfer_batchesCalls( + JinaAIEmbeddingsModel model, + Boolean expectMultipleResponses, + boolean nonTextInput + ) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - queueResponsesForChunkedInfer(model.getTaskSettings().getLateChunking()); + queueResponsesForChunkedInfer(expectMultipleResponses); PlainActionFuture> listener = new PlainActionFuture<>(); - // 2 input + // 2 inputs + String firstInput = "first_input"; + String secondInput = "second_input"; + List input; + if (nonTextInput) { + input = List.of( + new ChunkInferenceInput(firstInput), + new ChunkInferenceInput(new InferenceStringGroup(new InferenceString(IMAGE, BASE64, secondInput)), null) + ); + } else { + input = List.of(new ChunkInferenceInput(firstInput), new ChunkInferenceInput(secondInput)); + } service.chunkedInfer( 
model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + input, new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener ); - var results = listener.actionGet(TIMEOUT); + var results = listener.actionGet(TEST_REQUEST_TIMEOUT); assertThat(results, hasSize(2)); { assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset()); + assertEquals(new ChunkedInference.TextOffset(0, firstInput.length()), floatResult.chunks().getFirst().offset()); assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(EmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, @@ -1619,7 +1586,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset()); + assertEquals(new ChunkedInference.TextOffset(0, secondInput.length()), floatResult.chunks().getFirst().offset()); assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(EmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, @@ -1707,6 +1674,256 @@ private void queueResponsesForChunkedInfer(Boolean lateChunking) { } } + public void test_ChunkedInfer_noInputs() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "secret", + randomFrom(TEXT_EMBEDDING, EMBEDDING) + ); + + var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TEST_REQUEST_TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + + public void testEmbeddingInfer_returnsError_withNonJinaModel() throws IOException { + String modelName = "model_id"; + String serviceName = "service_name"; + var mockModel = getInvalidModel(modelName, serviceName); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.embeddingInfer( + mockModel, + new EmbeddingRequest(List.of(), InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "The internal model was invalid, please delete the service [%s] with id [%s] and add it again.", + serviceName, + modelName + ) + ) + ); + assertThat(thrownException.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + public void testEmbeddingInfer_returnsError_withRerankModel() throws IOException { + var model = JinaAIRerankModelTests.createModel("modelName"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + 
PlainActionFuture listener = new PlainActionFuture<>(); + service.embeddingInfer( + model, + new EmbeddingRequest(List.of(), InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [jinaai] with id [id] and add it again.") + ); + assertThat(thrownException.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + public void testEmbeddingInfer_returnsError_nonMultimodalModel_withNonTextInput() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "modelName", + JinaAIEmbeddingType.FLOAT, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + 128, + EMBEDDING, + false + ); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + var inputs = List.of( + new InferenceStringGroup("first_input"), + new InferenceStringGroup(new InferenceString(IMAGE, BASE64, "second_input")) + ); + service.embeddingInfer( + model, + new EmbeddingRequest(inputs, InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(thrownException.getMessage(), is("Non-text input provided for text-only model")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + public void testEmbeddingInfer_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new 
JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createEmbeddingModel(getUrl(webServer), "model", "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.embeddingInfer( + model, + new EmbeddingRequest(List.of(), InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testEmbeddingInfer_Ingest() throws IOException { + testEmbeddingInfer(randomFrom(InputType.INGEST, InputType.INTERNAL_INGEST), "retrieval.passage"); + } + + public void testEmbeddingInfer_Search() throws IOException { + testEmbeddingInfer(randomFrom(InputType.SEARCH, InputType.INTERNAL_SEARCH), "retrieval.query"); + } + + public void testEmbeddingInfer_clustering() throws IOException { + testEmbeddingInfer(InputType.CLUSTERING, "separation"); + } + + public void testEmbeddingInfer_classification() throws IOException { + testEmbeddingInfer(InputType.CLASSIFICATION, "classification"); + } + + public void testEmbeddingInfer_nullInputType() throws IOException { + testEmbeddingInfer(null, null); + } + + public void testEmbeddingInfer_unspecifiedInputType() throws IOException { + testEmbeddingInfer(InputType.UNSPECIFIED, null); + } + + private void testEmbeddingInfer(InputType inputType, String expectedJinaTask) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var 
service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + var responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + String modelName = "jina-clip-v2"; + int dimensions = 1024; + String apiKey = "apiKey"; + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + modelName, + JinaAIEmbeddingType.FLOAT, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + apiKey, + dimensions, + EMBEDDING, + true + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + var inputs = List.of( + new InferenceStringGroup("first_input"), + new InferenceStringGroup(new InferenceString(IMAGE, BASE64, "second_input")) + ); + service.embeddingInfer( + model, + new EmbeddingRequest(inputs, inputType, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); + + assertEquals( + GenericDenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), + result.asMap() + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); + + Map expectedRequestMap = new HashMap<>( + Map.of( + "input", + List.of(Map.of("text", "first_input"), Map.of("image", "second_input")), + "model", + modelName, + "embedding_type", + "float", + "dimensions", + dimensions + ) + ); + if (expectedJinaTask != null) { + expectedRequestMap.put("task", expectedJinaTask); + } + var 
requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap, is(expectedRequestMap)); + } + } + public void testDefaultSimilarity_BinaryEmbedding() { assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BINARY)); assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BIT)); @@ -1724,7 +1941,7 @@ public void testGetConfiguration() throws Exception { { "service": "jinaai", "name": "Jina AI", - "task_types": ["text_embedding", "rerank"], + "task_types": ["text_embedding", "rerank", "embedding"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -1733,7 +1950,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "embedding"] }, "dimensions": { "description": "The number of dimensions the resulting embeddings should have. For more information refer to https://api.jina.ai/docs#tag/embeddings/operation/create_embedding_v1_embeddings_post.", @@ -1742,7 +1959,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "embedding"] }, "embedding_type": { "description": "The type of embedding to return. One of [float, bit, binary]. bit and binary are equivalent and are encoded as bytes with signed int8 precision.", @@ -1752,7 +1969,7 @@ public void testGetConfiguration() throws Exception { "updatable": false, "default_value": "float", "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "embedding"] }, "similarity": { "description": "The similarity measure. One of [cosine, dot_product, l2_norm]. For float embeddings, the default similarity is dot_product. 
For bit and binary embeddings, the default similarity is l2_norm.", @@ -1761,7 +1978,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "embedding"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1770,7 +1987,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "embedding"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1779,7 +1996,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "embedding"] } } } @@ -1828,9 +2045,7 @@ private Map getRequestConfigMap( builtServiceSettings.putAll(serviceSettings); builtServiceSettings.putAll(secretSettings); - return new HashMap<>( - Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) - ); + return new HashMap<>(Map.of(SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)); } private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { @@ -1854,4 +2069,113 @@ public InferenceService createInferenceService() { protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(5500)); } + + private static void assertEmbeddingModelSettings( + Model model, + String modelName, + RateLimitSettings rateLimitSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer 
maxInputTokens, + JinaAIEmbeddingType embeddingType, + boolean multimodalModel, + JinaAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + String apiKey + ) { + assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + assertCommonModelSettings( + embeddingsModel, + DEFAULT_EMBEDDING_URL, + modelName, + rateLimitSettings, + similarity, + dimensions, + dimensionsSetByUser, + chunkingSettings, + apiKey + ); + + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(maxInputTokens)); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(embeddingType)); + assertThat(embeddingsModel.getServiceSettings().isMultimodal(), is(multimodalModel)); + + assertThat(embeddingsModel.getTaskSettings(), is(taskSettings)); + } + + private static void assertRerankModelSettings( + Model model, + String modelName, + RateLimitSettings rateLimitSettings, + String apiKey, + JinaAIRerankTaskSettings taskSettings + ) { + assertThat(model, instanceOf(JinaAIRerankModel.class)); + + var rerankModel = (JinaAIRerankModel) model; + assertCommonModelSettings(rerankModel, DEFAULT_RERANK_URL, modelName, rateLimitSettings, null, null, null, null, apiKey); + + assertThat(rerankModel.getTaskSettings(), is(taskSettings)); + } + + private static void assertCommonModelSettings( + T model, + String url, + String modelName, + RateLimitSettings rateLimitSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable ChunkingSettings chunkingSettings, + String apiKey + ) { + assertThat(model.uri().toString(), is(url)); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.rateLimitServiceSettings().rateLimitSettings(), is(rateLimitSettings)); + assertThat(model.getServiceSettings().similarity(), is(similarity)); + assertThat(model.getServiceSettings().dimensions(), is(dimensions)); + 
assertThat(model.getServiceSettings().dimensionsSetByUser(), is(dimensionsSetByUser)); + + assertThat(model.getConfigurations().getChunkingSettings(), is(chunkingSettings)); + + assertThat(model.apiKey().toString(), is(apiKey)); + } + + private static ActionListener getModelListenerForStatusException(String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is(expectedMessage)); + }); + } + + private static void assertParsePersistedConfigWithSecretsMinimalSettings( + JinaAIService service, + Utils.PersistedConfig persistedConfig, + String modelName, + String apiKey + ) { + var model = service.parsePersistedConfigWithSecrets( + "id", + randomFrom(service.supportedTaskTypes()), + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is(apiKey)); + } + + private static void assertParsePersistedConfigMinimalSettings( + JinaAIService service, + Utils.PersistedConfig persistedConfig, + String modelName + ) { + var model = service.parsePersistedConfig("id", randomFrom(service.supportedTaskTypes()), persistedConfig.config()); + + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is("")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..1d5a096af2ed7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java @@ -0,0 +1,280 @@ +/* + * Copyright Elasticsearch 
B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.PERSISTENT; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.REQUEST; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; +import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static 
org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.updateEmbeddingDetails; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.sameInstance; + +public class BaseJinaAIEmbeddingsServiceSettingsTests extends ESTestCase { + + public void testFromMap_parsesAllFields_textEmbedding_requestContext() { + testFromMap_parsesAllFields(TEXT_EMBEDDING, REQUEST, randomNonNegativeInt()); + } + + public void testFromMap_parsesAllFields_embedding_requestContext() { + testFromMap_parsesAllFields(TaskType.EMBEDDING, REQUEST, randomNonNegativeInt()); + } + + public void testFromMap_parsesAllFields_textEmbedding_persistentContext() { + testFromMap_parsesAllFields(TEXT_EMBEDDING, PERSISTENT, randomNonNegativeInt()); + } + + public void testFromMap_parsesAllFields_embedding_persistentContext() { + testFromMap_parsesAllFields(TaskType.EMBEDDING, PERSISTENT, randomNonNegativeInt()); + } + + public void testFromMap_parsesAllFields_textEmbedding_requestContext_dimensionsNotSet() { + testFromMap_parsesAllFields(TEXT_EMBEDDING, REQUEST, null); + } + + public void testFromMap_parsesAllFields_embedding_requestContext_dimensionsNotSet() { + testFromMap_parsesAllFields(TaskType.EMBEDDING, REQUEST, null); + } + + private void testFromMap_parsesAllFields(TaskType taskType, ConfigurationParseContext parseContext, Integer dimensions) { + var similarity = randomSimilarityMeasure(); + var maxInputTokens = randomNonNegativeInt(); + var model = randomAlphanumericOfLength(8); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var requestsPerMinute = randomNonNegativeInt(); + var settingsMap = getMapOfCommonEmbeddingSettings( + model, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute + ); + + var dimensionsSetByUser = dimensions != null; + if (parseContext == PERSISTENT) { + 
dimensionsSetByUser = randomBoolean(); + settingsMap.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + + var multimodalModel = false; + if (taskType == TaskType.EMBEDDING) { + multimodalModel = randomBoolean(); + settingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } + + var serviceSettings = BaseJinaAIEmbeddingsServiceSettings.fromMap(settingsMap, taskType, parseContext); + + assertThat(settingsMap, anEmptyMap()); + + assertServiceSettings( + serviceSettings, + taskType, + model, + requestsPerMinute, + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + public void testFromMap_doesNotRemoveMultimodalModelField_whenTaskTypeIsTextEmbedding() { + var settingsMap = getMapOfMinimalEmbeddingSettings(randomAlphanumericOfLength(8)); + + settingsMap.put(MULTIMODAL_MODEL, randomBoolean()); + + var settings = BaseJinaAIEmbeddingsServiceSettings.fromMap( + settingsMap, + TEXT_EMBEDDING, + randomFrom(ConfigurationParseContext.values()) + ); + + assertThat(settingsMap.get(MULTIMODAL_MODEL), notNullValue()); + assertThat(settings.isMultimodal(), is(false)); + } + + private static void assertServiceSettings( + BaseJinaAIEmbeddingsServiceSettings serviceSettings, + TaskType taskType, + String model, + Integer requestsPerMinute, + SimilarityMeasure similarity, + Integer dimensions, + Integer maxInputTokens, + JinaAIEmbeddingType embeddingType, + boolean dimensionsSetByUser, + Boolean multimodalModel + ) { + BaseJinaAIEmbeddingsServiceSettings expectedSettings; + if (taskType == TEXT_EMBEDDING) { + expectedSettings = new JinaAITextEmbeddingServiceSettings( + new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser + ); + } else if (taskType == TaskType.EMBEDDING) { + expectedSettings = new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), + similarity, 
+ dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } else { + throw new IllegalArgumentException("Invalid taskType " + taskType); + } + + assertThat(serviceSettings, is(expectedSettings)); + } + + public void testFromMap_withInvalidSimilarity_throwsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> BaseJinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", SIMILARITY, similarity)), + randomEmbeddingTaskType(), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] " + + "must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testFromMap_nonPositiveDimensions_throwsError() { + var dimensions = randomIntBetween(-5, 0); + var thrownException = expectThrows( + ValidationException.class, + () -> BaseJinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", DIMENSIONS, dimensions)), + randomEmbeddingTaskType(), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [%s] must be a positive integer;", + dimensions, + DIMENSIONS + ) + ) + ); + } + + public void testFromMap_withInvalidEmbeddingType_throwsError() { + var embeddingType = "bad_value"; + var thrownException = expectThrows( + ValidationException.class, + () -> BaseJinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), + randomEmbeddingTaskType(), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: [service_settings] Invalid value [bad_value] received. 
[embedding_type] " + + "must be one of [binary, bit, float];" + ) + ); + } + + public void testUpdateEmbeddingDetails_returnsSameInstance_whenEmbeddingSizeAndSimilarityAreSame() { + var settings = randomBoolean() + ? JinaAIEmbeddingServiceSettingsTests.createRandomWithNoNullValues() + : JinaAITextEmbeddingServiceSettingsTests.createRandomWithNoNullValues(); + + var updatedSettings = updateEmbeddingDetails(settings, settings.dimensions(), settings.similarity()); + + assertThat(updatedSettings, sameInstance(settings)); + } + + /** + * Returns a map containing only the fields that are required by both {@link JinaAIEmbeddingServiceSettings} and + * {@link JinaAITextEmbeddingServiceSettings} + */ + public static Map getMapOfMinimalEmbeddingSettings(String modelName) { + return getMapOfCommonEmbeddingSettings(modelName, null, null, null, null, null, null); + } + + /** + * Returns a map containing all fields that are used by both {@link JinaAIEmbeddingServiceSettings} and + * {@link JinaAITextEmbeddingServiceSettings} + */ + public static Map getMapOfCommonEmbeddingSettings( + String modelName, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + @Nullable Integer requestsPerMinute + ) { + var map = JinaAIServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute); + if (similarity != null) { + map.put(SIMILARITY, similarity.toString()); + } + if (dimensions != null) { + map.put(DIMENSIONS, dimensions); + } + if (dimensionsSetByUser != null) { + map.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + if (maxInputTokens != null) { + map.put(MAX_INPUT_TOKENS, maxInputTokens); + } + if (embeddingType != null) { + map.put(EMBEDDING_TYPE, embeddingType.toString()); + } + return map; + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java new file mode 100644 index 0000000000000..d152fb3d746cc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java @@ -0,0 +1,359 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; 
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; +import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; +import static org.hamcrest.Matchers.is; + +public class JinaAIEmbeddingServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static JinaAIEmbeddingServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = randomBoolean() ? null : randomSimilarityMeasure(); + Integer dimensions = randomBoolean() ? null : randomIntBetween(32, 256); + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + + var commonSettings = JinaAIServiceSettingsTests.createRandom(); + var embeddingType = randomBoolean() ? 
null : randomFrom(JinaAIEmbeddingType.values()); + var dimensionsSetByUser = randomBoolean(); + var multimodalModel = randomBoolean(); + + return new JinaAIEmbeddingServiceSettings( + commonSettings, + similarityMeasure, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + public static JinaAIEmbeddingServiceSettings createRandomWithNoNullValues() { + SimilarityMeasure similarityMeasure = randomSimilarityMeasure(); + Integer dimensions = randomIntBetween(32, 256); + Integer maxInputTokens = randomIntBetween(128, 256); + + var commonSettings = JinaAIServiceSettingsTests.createRandom(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var dimensionsSetByUser = randomBoolean(); + var multimodalModel = randomBoolean(); + + return new JinaAIEmbeddingServiceSettings( + commonSettings, + similarityMeasure, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var similarity = SimilarityMeasure.DOT_PRODUCT; + var dimensions = 1536; + var maxInputTokens = 512; + var model = "model"; + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var requestsPerMinute = 1234; + var multiModalModel = false; + var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + ServiceFields.MODEL_ID, + model, + EMBEDDING_TYPE, + embeddingType.toString(), + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)), + MULTIMODAL_MODEL, + multiModalModel + ) + ), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), + similarity, + dimensions, + 
maxInputTokens, + embeddingType, + true, + multiModalModel + ) + ) + ); + } + + public void testFromMap_onlyRequiredFields() { + var model = "model"; + var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, model)), + TaskType.EMBEDDING, + randomFrom(ConfigurationParseContext.values()) + ); + + assertThat( + serviceSettings, + is( + new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(model, null), + null, + null, + null, + JinaAIEmbeddingType.FLOAT, + false, + true + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)), + TaskType.EMBEDDING, + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%s] received. [similarity] " + + "must be one of [cosine, dot_product, l2_norm];", + similarity + ) + ) + ); + } + + public void testFromMap_InvalidEmbeddingType_ThrowsError() { + var embeddingType = "invalid"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), + TaskType.EMBEDDING, + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%s] received. 
[embedding_type] " + + "must be one of [binary, bit, float];", + embeddingType + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings("model", new RateLimitSettings(3)), + SimilarityMeasure.COSINE, + 5, + 10, + JinaAIEmbeddingType.FLOAT, + true, + true + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat(xContentResult, is(stripWhitespace(""" + { + "model_id":"model", + "rate_limit":{"requests_per_minute":3}, + "dimensions":5, + "embedding_type":"float", + "max_input_tokens":10, + "similarity":"cosine", + "multimodal_model":true, + "dimensions_set_by_user": true + }"""))); + } + + public void testUpdate() { + var settings = createRandom(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomIntBetween(32, 256); + + var newSettings = settings.update(similarity, dimensions); + + var expectedSettings = new JinaAIEmbeddingServiceSettings( + settings.getCommonSettings(), + similarity, + dimensions, + settings.maxInputTokens(), + settings.getEmbeddingType(), + settings.dimensionsSetByUser(), + settings.isMultimodal() + ); + + assertThat(newSettings, is(expectedSettings)); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIEmbeddingServiceSettings::new; + } + + @Override + protected JinaAIEmbeddingServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected JinaAIEmbeddingServiceSettings mutateInstance(JinaAIEmbeddingServiceSettings instance) throws IOException { + var commonSettings = instance.getCommonSettings(); + var similarity = instance.similarity(); + var dimensions = instance.dimensions(); + var maxInputTokens = instance.maxInputTokens(); + var embeddingType = instance.getEmbeddingType(); + var dimensionsSetByUser = 
instance.dimensionsSetByUser(); + var multimodal = instance.isMultimodal(); + switch (randomInt(6)) { + case 0 -> commonSettings = randomValueOtherThan(commonSettings, JinaAIServiceSettingsTests::createRandom); + case 1 -> similarity = randomValueOtherThan(similarity, () -> randomFrom(randomSimilarityMeasure(), null)); + case 2 -> dimensions = randomValueOtherThan(dimensions, ESTestCase::randomNonNegativeIntOrNull); + case 3 -> maxInputTokens = randomValueOtherThan(maxInputTokens, () -> randomFrom(randomIntBetween(128, 256), null)); + case 4 -> embeddingType = randomValueOtherThan(embeddingType, () -> randomFrom(JinaAIEmbeddingType.values())); + case 5 -> dimensionsSetByUser = randomValueOtherThan(dimensionsSetByUser, ESTestCase::randomBoolean); + case 6 -> multimodal = randomValueOtherThan(multimodal, ESTestCase::randomBoolean); + default -> throw new AssertionError("Illegal randomisation branch"); + } + + return new JinaAIEmbeddingServiceSettings( + commonSettings, + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodal + ); + } + + @Override + protected JinaAIEmbeddingServiceSettings mutateInstanceForVersion(JinaAIEmbeddingServiceSettings instance, TransportVersion version) { + Boolean multimodalModel = instance.isMultimodal(); + boolean dimensionsSetByUser = instance.dimensionsSetByUser(); + JinaAIEmbeddingType embeddingType = instance.getEmbeddingType(); + + // default to null for multimodal if node is on a version before embedding task support was added + if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED) == false) { + multimodalModel = null; + } + // default to false for dimensionsSetByUser if node is on a version before support for setting embedding dimensions was added + if (version.supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED) == false) { + dimensionsSetByUser = false; + } + // default to null embedding type if node is on a version before embedding type was introduced + if 
(version.supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED) == false) { + embeddingType = null; + } + return new JinaAIEmbeddingServiceSettings( + instance.getCommonSettings(), + instance.similarity(), + instance.dimensions(), + instance.maxInputTokens(), + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public static Map getServiceSettingsMap( + String modelName, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + @Nullable Integer requestsPerMinute, + @Nullable Boolean multimodalModel + ) { + var map = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute + ); + + if (multimodalModel != null) { + map.put(MULTIMODAL_MODEL, multimodalModel); + } + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java index dd119cfcca683..7232a576cb508 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import 
org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -19,6 +20,7 @@ import java.util.Map; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.hamcrest.Matchers.is; @@ -27,13 +29,13 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase { public void testConstructor_usesDefaultUrlWhenNull() { - var model = createModel(null, randomAlphaOfLength(10), randomAlphaOfLength(10)); + var model = createTextEmbeddingModel(null, randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThat(model.uri().toString(), is("https://api.jina.ai/v1/embeddings")); } public void testConstructor_usesUrlWhenSpecified() { String url = "some_URL"; - var model = createModel(url, randomAlphaOfLength(10), randomAlphaOfLength(10)); + var model = createTextEmbeddingModel(url, randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThat(model.uri().toString(), is(url)); } @@ -42,7 +44,8 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty() { null, "modelName", new JinaAIEmbeddingsTaskSettings(randomFrom(VALID_INPUT_TYPE_VALUES), randomBoolean()), - "api_key" + "api_key", + randomEmbeddingTaskType() ); var overriddenModel = JinaAIEmbeddingsModel.of(model, Map.of()); @@ -54,7 +57,8 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { null, "modelName", new JinaAIEmbeddingsTaskSettings(randomFrom(VALID_INPUT_TYPE_VALUES), randomBoolean()), - "api_key" + "api_key", + randomEmbeddingTaskType() ); var overriddenModel 
= JinaAIEmbeddingsModel.of(model, null); @@ -63,7 +67,7 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEqual() { JinaAIEmbeddingsTaskSettings taskSettings = new JinaAIEmbeddingsTaskSettings(randomFrom(VALID_INPUT_TYPE_VALUES), randomBoolean()); - var model = createModel(null, "modelName", taskSettings, "api_key"); + var model = createModel(null, "modelName", taskSettings, "api_key", randomEmbeddingTaskType()); var overriddenModel = JinaAIEmbeddingsModel.of( model, @@ -75,28 +79,37 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEqual() { public void testOf_SetsInputType_FromRequestTaskSettings_IfValid_OverridingStoredTaskSettings() { String modelName = "modelName"; String apiKey = "api_key"; - var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey); + TaskType taskType = randomEmbeddingTaskType(); + var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey, taskType); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH)); - var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, true), apiKey); + var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, true), apiKey, taskType); assertThat(overriddenModel, is(expectedModel)); } public void testOf_SetsLateChunking_FromRequestTaskSettings() { String modelName = "modelName"; String apiKey = "api_key"; - var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey); + TaskType taskType = randomEmbeddingTaskType(); + var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey, taskType); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST, 
false)); - var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), apiKey); + var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), apiKey, taskType); assertThat(overriddenModel, is(expectedModel)); } /** - * Returns a model with empty task settings, service settings and chunking settings + * Returns a model with empty task settings, service settings and chunking settings, using the {@link TaskType#TEXT_EMBEDDING} task type */ - public static JinaAIEmbeddingsModel createModel(String url, String modelName, String apiKey) { - return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey); + public static JinaAIEmbeddingsModel createTextEmbeddingModel(String url, String modelName, String apiKey) { + return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, TaskType.TEXT_EMBEDDING); + } + + /** + * Returns a model with empty task settings, service settings and chunking settings, using the {@link TaskType#EMBEDDING} task type + */ + public static JinaAIEmbeddingsModel createEmbeddingModel(String url, String modelName, String apiKey) { + return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, TaskType.EMBEDDING); } /** @@ -106,10 +119,11 @@ public static JinaAIEmbeddingsModel createModel( @Nullable String url, String modelName, JinaAIEmbeddingsTaskSettings taskSettings, - String apiKey + String apiKey, + TaskType taskType ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false); - return createModel(url, serviceSettings, taskSettings, null, apiKey); + var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, null); + return createModel(url, serviceSettings, taskSettings, null, apiKey, taskType); } /** @@ -120,10 +134,11 @@ public static JinaAIEmbeddingsModel createModel( String 
modelName, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable ChunkingSettings chunkingSettings, - String apiKey + String apiKey, + TaskType taskType ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false); - return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey); + var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, null); + return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey, taskType); } /** @@ -135,10 +150,22 @@ public static JinaAIEmbeddingsModel createModel( @Nullable JinaAIEmbeddingType embeddingType, JinaAIEmbeddingsTaskSettings taskSettings, String apiKey, - @Nullable Integer dimensions + @Nullable Integer dimensions, + TaskType taskType, + @Nullable Boolean multimodalModel ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, dimensions, null, embeddingType, dimensions != null); - return createModel(url, serviceSettings, taskSettings, null, apiKey); + var serviceSettings = getEmbeddingServiceSettings( + modelName, + null, + null, + dimensions, + null, + embeddingType, + dimensions != null, + taskType, + multimodalModel + ); + return createModel(url, serviceSettings, taskSettings, null, apiKey, taskType); } public static JinaAIEmbeddingsModel createModel( @@ -152,7 +179,9 @@ public static JinaAIEmbeddingsModel createModel( JinaAIEmbeddingsTaskSettings taskSettings, @Nullable ChunkingSettings chunkingSettings, String apiKey, - boolean dimensionsSetByUser + boolean dimensionsSetByUser, + TaskType taskType, + @Nullable Boolean multimodalModel ) { var serviceSettings = getEmbeddingServiceSettings( modelName, @@ -161,17 +190,20 @@ public static JinaAIEmbeddingsModel createModel( dimensions, maxInputTokens, embeddingType, - dimensionsSetByUser + dimensionsSetByUser, + taskType, + multimodalModel ); - return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey); 
+ return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey, taskType); } public static JinaAIEmbeddingsModel createModel( @Nullable String url, - JinaAIEmbeddingsServiceSettings serviceSettings, + BaseJinaAIEmbeddingsServiceSettings serviceSettings, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable ChunkingSettings chunkingSettings, - String apiKey + String apiKey, + TaskType taskType ) { return new JinaAIEmbeddingsModel( "id", @@ -179,26 +211,43 @@ public static JinaAIEmbeddingsModel createModel( taskSettings, chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())), - url + url, + taskType ); } - public static JinaAIEmbeddingsServiceSettings getEmbeddingServiceSettings( + public static BaseJinaAIEmbeddingsServiceSettings getEmbeddingServiceSettings( String modelName, @Nullable RateLimitSettings rateLimitSettings, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable JinaAIEmbeddingType embeddingType, - boolean dimensionsSetByUser + boolean dimensionsSetByUser, + TaskType taskType, + @Nullable Boolean multimodalModel ) { - return new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings(modelName, rateLimitSettings), - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser - ); + if (taskType == TaskType.TEXT_EMBEDDING) { + return new JinaAITextEmbeddingServiceSettings( + new JinaAIServiceSettings(modelName, rateLimitSettings), + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser + ); + } else if (taskType == TaskType.EMBEDDING) { + return new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(modelName, rateLimitSettings), + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java similarity index 77% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java index 423f631c13740..d3e1ae8ae4cb9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java @@ -40,26 +40,43 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; -import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; import static org.hamcrest.Matchers.is; -public class 
JinaAIEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase { +public class JinaAITextEmbeddingServiceSettingsTests extends AbstractBWCWireSerializationTestCase { - private static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_type_support_added" - ); - - public static JinaAIEmbeddingsServiceSettings createRandom() { - SimilarityMeasure similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - Integer dimensions = 1024; + public static JinaAITextEmbeddingServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = randomBoolean() ? null : randomSimilarityMeasure(); + Integer dimensions = randomBoolean() ? null : randomIntBetween(32, 256); Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + var commonSettings = JinaAIServiceSettingsTests.createRandom(); + var embeddingType = randomBoolean() ? null : randomFrom(JinaAIEmbeddingType.values()); + var dimensionsSetByUser = randomBoolean(); + + return new JinaAITextEmbeddingServiceSettings( + commonSettings, + similarityMeasure, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser + ); + } + + public static JinaAITextEmbeddingServiceSettings createRandomWithNoNullValues() { + SimilarityMeasure similarityMeasure = randomSimilarityMeasure(); + Integer dimensions = randomIntBetween(32, 256); + Integer maxInputTokens = randomIntBetween(128, 256); + var commonSettings = JinaAIServiceSettingsTests.createRandom(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var dimensionsSetByUser = randomBoolean(); - return new JinaAIEmbeddingsServiceSettings( + return new JinaAITextEmbeddingServiceSettings( commonSettings, similarityMeasure, dimensions, @@ -76,7 +93,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { var model = "model"; var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var requestsPerMinute = 1234; - var serviceSettings = 
JinaAIEmbeddingsServiceSettings.fromMap( + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>( Map.of( ServiceFields.SIMILARITY, @@ -99,7 +116,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -114,7 +131,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() { var url = "https://www.abc.com"; var model = "model"; - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(URL, url, ServiceFields.MODEL_ID, model)), ConfigurationParseContext.REQUEST ); @@ -122,7 +139,7 @@ public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNot assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, null), null, null, @@ -143,7 +160,7 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var requestsPerMinute = 1234; var dimensionsSetByUser = randomBoolean(); - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>( Map.of( URL, @@ -170,7 +187,7 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -187,7 +204,7 @@ public void testFromMap_WithModelId() { var dimensions = 1536; var maxInputTokens = 512; var model = "model"; - var 
serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>( Map.of( ServiceFields.SIMILARITY, @@ -206,7 +223,7 @@ public void testFromMap_WithModelId() { assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, null), similarity, dimensions, @@ -223,7 +240,7 @@ public void testFromMap_WithEmbeddingType() { var dimensions = 1536; var maxInputTokens = 512; var model = "model"; - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>( Map.of( ServiceFields.SIMILARITY, @@ -244,7 +261,7 @@ public void testFromMap_WithEmbeddingType() { assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, null), similarity, dimensions, @@ -260,7 +277,7 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { var similarity = "by_size"; var thrownException = expectThrows( ValidationException.class, - () -> JinaAIEmbeddingsServiceSettings.fromMap( + () -> JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", SIMILARITY, similarity)), ConfigurationParseContext.PERSISTENT ) @@ -279,7 +296,7 @@ public void testFromMap_nonPositiveDimensions_ThrowsError() { var dimensions = randomIntBetween(-5, 0); var thrownException = expectThrows( ValidationException.class, - () -> JinaAIEmbeddingsServiceSettings.fromMap( + () -> JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", DIMENSIONS, dimensions)), randomFrom(ConfigurationParseContext.values()) ) @@ -305,7 +322,7 @@ public void testToXContent_WritesAllValues() throws IOException { var maxInputTokens = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var dimensionsSetByUser 
= false; - var serviceSettings = new JinaAIEmbeddingsServiceSettings( + var serviceSettings = new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(modelName, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -337,7 +354,7 @@ public void testToXContentFragmentOfExposedFields_WritesAllValues() throws IOExc var maxInputTokens = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var dimensionsSetByUser = false; - var serviceSettings = new JinaAIEmbeddingsServiceSettings( + var serviceSettings = new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(modelName, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -362,18 +379,37 @@ public void testToXContentFragmentOfExposedFields_WritesAllValues() throws IOExc }""", modelName, requestsPerMinute, dimensions, embeddingType, maxInputTokens, similarity)))); } + public void testUpdate() { + var settings = createRandom(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomIntBetween(32, 256); + + var newSettings = settings.update(similarity, dimensions); + + var expectedSettings = new JinaAITextEmbeddingServiceSettings( + settings.getCommonSettings(), + similarity, + dimensions, + settings.maxInputTokens(), + settings.getEmbeddingType(), + settings.dimensionsSetByUser() + ); + + assertThat(newSettings, is(expectedSettings)); + } + @Override - protected Writeable.Reader instanceReader() { - return JinaAIEmbeddingsServiceSettings::new; + protected Writeable.Reader instanceReader() { + return JinaAITextEmbeddingServiceSettings::new; } @Override - protected JinaAIEmbeddingsServiceSettings createTestInstance() { + protected JinaAITextEmbeddingServiceSettings createTestInstance() { return createRandom(); } @Override - protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsServiceSettings instance) throws IOException { + protected JinaAITextEmbeddingServiceSettings 
mutateInstance(JinaAITextEmbeddingServiceSettings instance) throws IOException { var commonSettings = instance.getCommonSettings(); var similarity = instance.similarity(); var dimensions = instance.dimensions(); @@ -390,7 +426,7 @@ protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsService default -> throw new AssertionError("Illegal randomisation branch"); } - return new JinaAIEmbeddingsServiceSettings( + return new JinaAITextEmbeddingServiceSettings( commonSettings, similarity, dimensions, @@ -401,7 +437,10 @@ protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsService } @Override - protected JinaAIEmbeddingsServiceSettings mutateInstanceForVersion(JinaAIEmbeddingsServiceSettings instance, TransportVersion version) { + protected JinaAITextEmbeddingServiceSettings mutateInstanceForVersion( + JinaAITextEmbeddingServiceSettings instance, + TransportVersion version + ) { if (version.supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) { return instance; } @@ -414,7 +453,7 @@ protected JinaAIEmbeddingsServiceSettings mutateInstanceForVersion(JinaAIEmbeddi embeddingType = null; } - return new JinaAIEmbeddingsServiceSettings( + return new JinaAITextEmbeddingServiceSettings( instance.getCommonSettings(), instance.similarity(), instance.dimensions(), @@ -432,13 +471,23 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } - public static Map getServiceSettingsMap(String model, @Nullable JinaAIEmbeddingType embeddingType) { - var map = new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(model)); - - if (embeddingType != null) { - map.put(EMBEDDING_TYPE, embeddingType.toString()); - } - - return map; + public static Map getServiceSettingsMap( + String modelName, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + 
@Nullable Integer requestsPerMinute + ) { + return getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java index c906e9203338d..317f52d970200 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java @@ -8,7 +8,10 @@ package org.elasticsearch.xpack.inference.services.jinaai.request; import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceString; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -21,18 +24,21 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.inference.InferenceString.DataFormat.BASE64; +import static org.elasticsearch.inference.InferenceString.DataType.IMAGE; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests.getEmbeddingServiceSettings; import static org.hamcrest.CoreMatchers.is; public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase { - public 
void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + public void testXContent_nonMultimodal_WritesAllFields_WhenTheyAreDefined() throws IOException { var modelName = "modelName"; var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var lateChunking = randomBoolean(); var dimensions = randomNonNegativeInt(); var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.INTERNAL_INGEST, createModel( null, @@ -40,7 +46,9 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException embeddingType, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, lateChunking), "apiKey", - dimensions + dimensions, + randomEmbeddingTaskType(), + false ) ); @@ -64,11 +72,20 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException ); } - public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { + public void testXContent_nonMultimodal_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.INTERNAL_INGEST, - createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, false), "apiKey", null) + createModel( + null, + "modelName", + null, + new JinaAIEmbeddingsTaskSettings(null, false), + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -79,15 +96,27 @@ public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFiel {"input":["abc"],"model":"modelName","embedding_type":"float","task":"retrieval.passage","late_chunking":false}""")); } - public void testXContent_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueButInputExceedsWordCountLimit() throws IOException { + public void 
testXContent_nonMultimodal_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueButInputExceedsWordCountLimit() + throws IOException { int wordCount = JinaAIEmbeddingsRequestEntity.MAX_WORD_COUNT_FOR_LATE_CHUNKING + 1; - var testInput = IntStream.range(0, wordCount / 2).mapToObj(i -> "word" + i).collect(Collectors.joining(" ")) + "."; + var testInput = new InferenceStringGroup( + IntStream.range(0, wordCount / 2).mapToObj(i -> "word" + i).collect(Collectors.joining(" ")) + "." + ); var testInputs = List.of(testInput, testInput, testInput); var entity = new JinaAIEmbeddingsRequestEntity( testInputs, InputType.INTERNAL_INGEST, - createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, true), "apiKey", null) + createModel( + null, + "modelName", + null, + new JinaAIEmbeddingsTaskSettings(null, true), + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -96,14 +125,23 @@ public void testXContent_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueB assertThat(xContentResult, is(Strings.format(""" {"input":["%s","%s","%s"],"model":"modelName","embedding_type":"float",\ - "task":"retrieval.passage","late_chunking":false}""", testInput, testInput, testInput))); + "task":"retrieval.passage","late_chunking":false}""", testInput.textValue(), testInput.textValue(), testInput.textValue()))); } - public void testXContent_WritesInputTypeField_WhenItIsDefinedOnlyInTaskSettings() throws IOException { + public void testXContent_nonMultimodal_WritesInputTypeField_WhenItIsDefinedOnlyInTaskSettings() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), null, - createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, null), "apiKey", null) + createModel( + null, + "modelName", + null, + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, null), + "apiKey", + null, + 
randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -114,11 +152,20 @@ public void testXContent_WritesInputTypeField_WhenItIsDefinedOnlyInTaskSettings( {"input":["abc"],"model":"modelName","embedding_type":"float","task":"retrieval.query"}""")); } - public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + public void testXContent_nonMultimodal_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), null, - createModel(null, "modelName", null, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null) + createModel( + null, + "modelName", + null, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -129,11 +176,20 @@ public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws I {"input":["abc"],"model":"modelName","embedding_type":"float"}""")); } - public void testXContent_EmbeddingTypesBit() throws IOException { + public void testXContent_nonMultimodal_EmbeddingTypesBit() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.CLUSTERING, - createModel(null, "model", JinaAIEmbeddingType.BIT, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null) + createModel( + null, + "model", + JinaAIEmbeddingType.BIT, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -144,11 +200,20 @@ public void testXContent_EmbeddingTypesBit() throws IOException { {"input":["abc"],"model":"model","embedding_type":"binary","task":"separation"}""")); } - public void 
testXContent_EmbeddingTypesBinary() throws IOException { + public void testXContent_nonMultimodal_EmbeddingTypesBinary() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.SEARCH, - createModel(null, "model", JinaAIEmbeddingType.BINARY, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null) + createModel( + null, + "model", + JinaAIEmbeddingType.BINARY, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -159,12 +224,13 @@ public void testXContent_EmbeddingTypesBinary() throws IOException { {"input":["abc"],"model":"model","embedding_type":"binary","task":"retrieval.query"}""")); } - public void testXContent_doesNotWriteDimensions_whenDimensionsSetByUserIsFalse() throws IOException { - var serviceSettings = getEmbeddingServiceSettings("modelName", null, null, 512, null, null, false); + public void testXContent_nonMultimodal_doesNotWriteDimensions_whenDimensionsSetByUserIsFalse() throws IOException { + var taskType = randomEmbeddingTaskType(); + var serviceSettings = getEmbeddingServiceSettings("modelName", null, null, 512, null, null, false, taskType, false); var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), null, - createModel(null, serviceSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, "apiKey") + createModel(null, serviceSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, "apiKey", taskType) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -174,4 +240,105 @@ public void testXContent_doesNotWriteDimensions_whenDimensionsSetByUserIsFalse() assertThat(xContentResult, is(""" {"input":["abc"],"model":"modelName","embedding_type":"float"}""")); } + + public void testXContent_multimodal_WritesAllFields_WhenTheyAreDefined() 
throws IOException { + var modelName = "modelName"; + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var lateChunking = randomBoolean(); + var dimensions = randomNonNegativeInt(); + String textInput = "text input"; + String imageInput = "image input"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(new InferenceString(IMAGE, imageInput))), + InputType.INTERNAL_INGEST, + createModel( + null, + modelName, + embeddingType, + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, lateChunking), + "apiKey", + dimensions, + TaskType.EMBEDDING, + true + ) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat( + xContentResult, + is( + Strings.format( + """ + {"input":[{"text":"%s"},{"image":"%s"}],"model":"%s",\ + "embedding_type":"%s","task":"retrieval.passage","late_chunking":false,"dimensions":%d}""", + textInput, + imageInput, + modelName, + embeddingType.toRequestString(), + dimensions + ) + ) + ); + } + + public void testXContent_multimodal_WritesTrueLateChunkingField_WhenLateChunkingSetToTrueAndInputContainsOnlyTextInput() + throws IOException { + + String textInput1 = "text input 1"; + String textInput2 = "text input 2"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput1), new InferenceStringGroup(textInput2)), + InputType.INTERNAL_INGEST, + createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, true), "apiKey", null, TaskType.EMBEDDING, true) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"input":[{"text":"%s"},{"text":"%s"}],"model":"modelName","embedding_type":"float",\ + 
"task":"retrieval.passage","late_chunking":true}""", textInput1, textInput2))); + } + + public void testXContent_multimodal_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueAndInputContainsNonTextInput() + throws IOException { + + String textInput = "text input"; + String imageInput = "image input"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(List.of(new InferenceString(IMAGE, BASE64, imageInput)))), + InputType.INTERNAL_INGEST, + createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, true), "apiKey", null, TaskType.EMBEDDING, true) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"input":[{"text":"%s"},{"image":"%s"}],"model":"modelName","embedding_type":"float",\ + "task":"retrieval.passage","late_chunking":false}""", textInput, imageInput))); + } + + public void testXContent_multimodal_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + String textInput = "text input"; + String imageInput = "image input"; + String modelName = "modelName"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(new InferenceString(IMAGE, imageInput))), + null, + createModel(null, modelName, null, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null, TaskType.EMBEDDING, null) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"input":[{"text":"%s"},{"image":"%s"}],"model":"%s","embedding_type":"float"}""", textInput, imageInput, modelName))); + } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java index 4e8c79bf72dfe..5567ce83a9af9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java @@ -9,7 +9,10 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InferenceString; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InputTypeTests; @@ -19,6 +22,8 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,7 +33,7 @@ import static org.hamcrest.Matchers.is; public class JinaAIEmbeddingsRequestTests extends ESTestCase { - public void testCreateRequest_AllOptionsDefined() throws IOException { + public void testCreateRequest_AllOptionsDefined_textEmbedding() throws IOException { var inputType = InputTypeTests.randomWithNull(); boolean lateChunking = randomBoolean(); var modelName = "modelName"; @@ -37,7 +42,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { var input = List.of("abc"); var apiKey = "api-key"; var dimensions = 512; - var request = createRequest( + var request = createTextOnlyRequest( input, inputType, JinaAIEmbeddingsModelTests.createModel( @@ -46,7 +51,9 @@ public void 
testCreateRequest_AllOptionsDefined() throws IOException { embeddingType, new JinaAIEmbeddingsTaskSettings(null, lateChunking), apiKey, - dimensions + dimensions, + TaskType.TEXT_EMBEDDING, + null ) ); @@ -60,47 +67,82 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); + var expectedRequestMap = new HashMap<>( + Map.of( + "input", + input, + "model", + modelName, + "embedding_type", + embeddingType.toRequestString(), + "late_chunking", + lateChunking, + "dimensions", + dimensions + ) + ); + if (InputType.isSpecified(inputType)) { + expectedRequestMap.put("task", convertInputType(inputType)); + } + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); + } + + public void testCreateRequest_AllOptionsDefined_multimodalEmbedding() throws IOException { + var inputType = InputTypeTests.randomWithNull(); + boolean lateChunking = randomBoolean(); + var modelName = "modelName"; + var url = "url"; + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var input = List.of("abc"); + var apiKey = "api-key"; + var dimensions = 512; + var request = createMultimodalRequest( + input, + inputType, + JinaAIEmbeddingsModelTests.createModel( + url, + modelName, + embeddingType, + new JinaAIEmbeddingsTaskSettings(null, lateChunking), + apiKey, + dimensions, + TaskType.EMBEDDING, + true + ) + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is(url)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + 
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); + assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); + + var expectedRequestMap = new HashMap<>( + Map.of( + "input", + List.of(Map.of("image", input.getFirst())), + "model", + modelName, + "embedding_type", + embeddingType.toRequestString(), + "late_chunking", + false, + "dimensions", + dimensions + ) + ); if (InputType.isSpecified(inputType)) { - var convertedInputType = convertInputType(inputType); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "embedding_type", - embeddingType.toRequestString(), - "task", - convertedInputType, - "late_chunking", - lateChunking, - "dimensions", - dimensions - ) - ) - ); - } else { - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "embedding_type", - embeddingType.toRequestString(), - "late_chunking", - lateChunking, - "dimensions", - dimensions - ) - ) - ); + expectedRequestMap.put("task", convertInputType(inputType)); } + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); } public void testCreateRequest_TaskSettingsInputType() throws IOException { @@ -111,7 +153,7 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { List input = List.of("abc"); String apiKey = "api-key"; int dimensions = 512; - var request = createRequest( + var request = createTextOnlyRequest( input, null, JinaAIEmbeddingsModelTests.createModel( @@ -120,7 +162,9 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { embeddingType, new JinaAIEmbeddingsTaskSettings(inputType, null), apiKey, - dimensions + dimensions, + TaskType.TEXT_EMBEDDING, + null ) ); @@ -134,32 +178,15 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { 
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); - var requestMap = entityAsMap(httpPost.getEntity().getContent()); + var expectedRequestMap = new HashMap<>( + Map.of("input", input, "model", modelName, "embedding_type", embeddingType.toRequestString(), "dimensions", dimensions) + ); if (InputType.isSpecified(inputType)) { - var convertedInputType = convertInputType(inputType); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "embedding_type", - embeddingType.toRequestString(), - "task", - convertedInputType, - "dimensions", - dimensions - ) - ) - ); - } else { - assertThat( - requestMap, - is(Map.of("input", input, "model", modelName, "embedding_type", embeddingType.toRequestString(), "dimensions", dimensions)) - ); + expectedRequestMap.put("task", convertInputType(inputType)); } + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); } public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOException { @@ -169,10 +196,16 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti var url = "url"; List input = List.of("abc"); String apiKey = "api-key"; - var request = createRequest( + var request = createTextOnlyRequest( input, requestInputType, - JinaAIEmbeddingsModelTests.createModel(url, modelName, new JinaAIEmbeddingsTaskSettings(taskSettingsInputType, null), apiKey) + JinaAIEmbeddingsModelTests.createModel( + url, + modelName, + new JinaAIEmbeddingsTaskSettings(taskSettingsInputType, null), + apiKey, + TaskType.TEXT_EMBEDDING + ) ); var httpRequest = request.createHttpRequest(); @@ -185,19 +218,36 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti 
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); - var requestMap = entityAsMap(httpPost.getEntity().getContent()); + var expectedRequestMap = new HashMap<>(Map.of("input", input, "model", modelName, "embedding_type", "float")); if (InputType.isSpecified(requestInputType)) { - var convertedInputType = convertInputType(requestInputType); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "task", convertedInputType))); + expectedRequestMap.put("task", convertInputType(requestInputType)); } else if (InputType.isSpecified(taskSettingsInputType)) { - var convertedInputType = convertInputType(taskSettingsInputType); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "task", convertedInputType))); - } else { - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float"))); + expectedRequestMap.put("task", convertInputType(taskSettingsInputType)); + } + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); + } + + public static JinaAIEmbeddingsRequest createMultimodalRequest(List inputs, InputType inputType, JinaAIEmbeddingsModel model) { + boolean isTextInput = false; + List convertedInput = new ArrayList<>(); + for (String input : inputs) { + InferenceString inferenceString; + if (isTextInput) { + inferenceString = new InferenceString(InferenceString.DataType.TEXT, InferenceString.DataFormat.TEXT, input); + } else { + inferenceString = new InferenceString(InferenceString.DataType.IMAGE, InferenceString.DataFormat.BASE64, input); + } + isTextInput = isTextInput == false; + var inferenceStringGroup = new InferenceStringGroup(inferenceString); + convertedInput.add(inferenceStringGroup); } + return new 
JinaAIEmbeddingsRequest(convertedInput, inputType, model); } - public static JinaAIEmbeddingsRequest createRequest(List input, InputType inputType, JinaAIEmbeddingsModel model) { - return new JinaAIEmbeddingsRequest(input, inputType, model); + public static JinaAIEmbeddingsRequest createTextOnlyRequest(List inputs, InputType inputType, JinaAIEmbeddingsModel model) { + List convertedInput = inputs.stream().map(InferenceStringGroup::new).toList(); + return new JinaAIEmbeddingsRequest(convertedInput, inputType, model); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java index a94dabb1bab63..982be73f42046 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import java.io.IOException; -import java.util.HashMap; import java.util.Map; import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; @@ -70,7 +69,11 @@ protected JinaAIRerankServiceSettings mutateInstanceForVersion(JinaAIRerankServi return instance; } - public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { - return new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(model)); + public static Map getServiceSettingsMap(String model) { + return getServiceSettingsMap(model, null); + } + + public static Map getServiceSettingsMap(String model, @Nullable Integer requestsPerMinute) { + return JinaAIServiceSettingsTests.getServiceSettingsMap(model, requestsPerMinute); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java index 521a7e861371e..f33c18c5a39d7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java @@ -119,10 +119,6 @@ protected JinaAIRerankTaskSettings mutateInstance(JinaAIRerankTaskSettings insta } } - public static Map getTaskSettingsMapEmpty() { - return new HashMap<>(); - } - public static Map getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, @Nullable Boolean returnDocuments) { var map = new HashMap(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java index 5dab78b37e867..86270cf1c18b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java @@ -10,27 +10,46 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import 
org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.request.JinaAIEmbeddingsRequestTests; +import org.hamcrest.Matchers; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.BINARY; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.BIT; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.FLOAT; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; import static org.mockito.Mockito.mock; public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase { - public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + public void testFromResponse_CreatesResultsForASingleItem_textEmbeddingTask() throws IOException { + testFromResponse_singleItem(TaskType.TEXT_EMBEDDING); + } + + 
public void testFromResponse_CreatesResultsForASingleItem_embeddingTask() throws IOException { + testFromResponse_singleItem(TaskType.EMBEDDING); + } + + private static void testFromResponse_singleItem(TaskType taskType) throws IOException { String responseJson = """ { "object": "list", @@ -53,22 +72,30 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { """; InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") + JinaAIEmbeddingsModelTests.createModel(null, "modelName", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", taskType) ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); + assertResultsType(taskType, JinaAIEmbeddingType.FLOAT, parsedResults); assertThat( - ((DenseEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + ((EmbeddingFloatResults) parsedResults).embeddings(), + Matchers.is(List.of(new EmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } - public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + public void testFromResponse_CreatesResultsForMultipleItems_textEmbeddingTask() throws IOException { + testFromResponse_multipleItems(TaskType.TEXT_EMBEDDING); + } + + public void testFromResponse_CreatesResultsForMultipleItems_embeddingTask() throws IOException { + testFromResponse_multipleItems(TaskType.EMBEDDING); + } + + private static void testFromResponse_multipleItems(TaskType taskType) throws IOException { String responseJson = """ { "object": "list", @@ -99,21 +126,21 @@ public void 
testFromResponse_CreatesResultsForMultipleItems() throws IOException """; InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") + JinaAIEmbeddingsModelTests.createModel(null, "modelName", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", taskType) ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); + assertResultsType(taskType, JinaAIEmbeddingType.FLOAT, parsedResults); assertThat( - ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + ((EmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new EmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new EmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -143,14 +170,7 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { var thrownException = expectThrows( IllegalStateException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat(thrownException.getMessage(), is("Failed to find required field [data] in JinaAI embeddings response")); @@ -180,14 +200,7 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { var thrownException = 
expectThrows( ParsingException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat( @@ -220,14 +233,7 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { var thrownException = expectThrows( IllegalStateException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat(thrownException.getMessage(), is("Failed to find required field [embedding] in JinaAI embeddings response")); @@ -256,14 +262,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { var thrownException = expectThrows( ParsingException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat( @@ -272,7 +271,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { ); } - public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOException { + public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { String responseJson = """ { "object": "list", @@ 
-281,11 +280,7 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep "object": "embedding", "index": 0, "embedding": [ - -55, - 74, - 101, - 67, - 83 + {} ] } ], @@ -297,29 +292,18 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep } """; - InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel( - null, - "modelName", - JinaAIEmbeddingType.BINARY, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - "secret", - null - ) - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + var thrownException = expectThrows( + ParsingException.class, + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat( - ((DenseEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") ); } - public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOException { + public void testFromResponse_withBitEmbeddingType_FailsWhenEmbeddingValueIsLargerThanByte() { String responseJson = """ { "object": "list", @@ -328,11 +312,7 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio "object": "embedding", "index": 0, "embedding": [ - -55, - 74, - 101, - 67, - 83 + -1024 ] } ], @@ -344,29 +324,41 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio } """; - InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + var thrownException = expectThrows(IllegalArgumentException.class, () -> 
callFromResponse(responseJson, randomFrom(BIT, BINARY))); + + assertThat(thrownException.getMessage(), is("Value [-1024] is out of range for a byte")); + } + + private static void callFromResponse(String responseJson, JinaAIEmbeddingType embeddingType) throws IOException { + JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), JinaAIEmbeddingsModelTests.createModel( null, "modelName", - JinaAIEmbeddingType.BIT, + embeddingType, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", - null + null, + randomEmbeddingTaskType(), + false ) ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + } - assertThat( - ((DenseEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) - ); + public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOException { + fromResponse_withNonFloatEmbeddingType(BINARY); } - public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { + public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOException { + fromResponse_withNonFloatEmbeddingType(BIT); + } + + private static void fromResponse_withNonFloatEmbeddingType(JinaAIEmbeddingType embeddingType) throws IOException { + assertThat(embeddingType, oneOf(BIT, BINARY)); String responseJson = """ { "object": "list", @@ -375,7 +367,11 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { "object": "embedding", "index": 0, "embedding": [ - {} + -55, + 74, + 101, + 67, + 83 ] } ], @@ -387,21 +383,29 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { } """; - var thrownException = expectThrows( - ParsingException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - 
JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + var taskType = randomEmbeddingTaskType(); + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( + List.of("abc"), + InputTypeTests.randomWithNull(), + JinaAIEmbeddingsModelTests.createModel( + null, + "modelName", + embeddingType, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "secret", + null, + taskType, + false + ) + ), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + assertResultsType(taskType, embeddingType, parsedResults); assertThat( - thrownException.getMessage(), - is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") + ((EmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new EmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -446,24 +450,42 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + TaskType taskType = randomEmbeddingTaskType(); + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") + JinaAIEmbeddingsModelTests.createModel(null, "modelName", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", taskType) ), new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); + assertResultsType(taskType, FLOAT, parsedResults); assertThat( - parsedResults.embeddings(), + ((EmbeddingFloatResults) parsedResults).embeddings(), is( 
List.of( - new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new EmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new EmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new EmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); } + + private static void assertResultsType(TaskType taskType, JinaAIEmbeddingType embeddingType, InferenceServiceResults parsedResults) { + if (taskType.equals(TaskType.TEXT_EMBEDDING)) { + switch (embeddingType) { + case FLOAT -> assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); + case BIT, BINARY -> assertThat(parsedResults, instanceOf(DenseEmbeddingBitResults.class)); + } + } else if (taskType.equals(TaskType.EMBEDDING)) { + switch (embeddingType) { + case FLOAT -> assertThat(parsedResults, instanceOf(GenericDenseEmbeddingFloatResults.class)); + case BIT, BINARY -> assertThat(parsedResults, instanceOf(GenericDenseEmbeddingBitResults.class)); + } + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java index 7431ca06b8264..06c90f3d73575 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java @@ -13,8 +13,10 @@ import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceService; import 
org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -22,6 +24,7 @@ import org.mockito.Mock; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.inference.services.validation.SimpleEmbeddingServiceIntegrationValidator.TEST_IMAGE_BASE64_INPUT; import static org.elasticsearch.xpack.inference.services.validation.SimpleEmbeddingServiceIntegrationValidator.TEST_TEXT_INPUT; @@ -37,10 +40,6 @@ public class SimpleEmbeddingServiceIntegrationValidatorTests extends ESTestCase { - private static final EmbeddingRequest EXPECTED_REQUEST = new EmbeddingRequest( - List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT), - InputType.INTERNAL_INGEST - ); private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @Mock @@ -48,6 +47,8 @@ public class SimpleEmbeddingServiceIntegrationValidatorTests extends ESTestCase @Mock private Model mockModel; @Mock + private ServiceSettings mockServiceSettings; + @Mock private ActionListener mockActionListener; @Mock private InferenceServiceResults mockInferenceServiceResults; @@ -61,11 +62,14 @@ public void setup() { underTest = new SimpleEmbeddingServiceIntegrationValidator(); when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); + when(mockModel.getServiceSettings()).thenReturn(mockServiceSettings); + when(mockServiceSettings.isMultimodal()).thenReturn(randomBoolean()); } public void testValidate_ServiceThrowsException() { + var expectedRequest = getExpectedRequest(); doThrow(ElasticsearchStatusException.class).when(mockInferenceService) - .embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + .embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); 
assertThrows( ElasticsearchStatusException.class, @@ -114,11 +118,12 @@ public void testValidate_CallsListenerOnFailure_WhenServiceThrowsException() { } private void mockSuccessfulCallToService(InferenceServiceResults result) { + var expectedRequest = getExpectedRequest(); doAnswer(ans -> { ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(result); return null; - }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } @@ -128,17 +133,30 @@ private void mockNullResponseFromService() { } private void mockFailureResponseFromService(Exception exception) { + var expectedRequest = getExpectedRequest(); doAnswer(ans -> { ActionListener responseListener = ans.getArgument(3); responseListener.onFailure(exception); return null; - }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyCallToService() { - verify(mockInferenceService).embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + var expectedRequest = getExpectedRequest(); + verify(mockModel).getServiceSettings(); + verify(mockInferenceService).embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); } + + private EmbeddingRequest getExpectedRequest() { + List inputs; + if (mockServiceSettings.isMultimodal()) { + inputs = List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT); + } else { + inputs = List.of(TEST_TEXT_INPUT); + } + return new EmbeddingRequest(inputs, 
InputType.INTERNAL_INGEST, Map.of()); + } } From 1eb97bbfff800dfab418e8175750f62607a08ab3 Mon Sep 17 00:00:00 2001 From: Donal Evans Date: Wed, 7 Jan 2026 13:38:55 -0800 Subject: [PATCH 2/5] Update docs/changelog/140323.yaml --- docs/changelog/140323.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/140323.yaml diff --git a/docs/changelog/140323.yaml b/docs/changelog/140323.yaml new file mode 100644 index 0000000000000..1ba19da787a22 --- /dev/null +++ b/docs/changelog/140323.yaml @@ -0,0 +1,5 @@ +pr: 140323 +summary: "[Inference API] Add support for embedding task to JinaAI service" +area: Inference +type: enhancement +issues: [] From f84934c7e09153966d8b5483ee97bb0c4a85aa8e Mon Sep 17 00:00:00 2001 From: donalevans Date: Thu, 8 Jan 2026 16:31:20 -0800 Subject: [PATCH 3/5] Review feedback - Refactor BaseJinaAIEmbeddingsServiceSettings.fromMap() to use generics - Make multimodalModel field non-optional and introduce abstract optionallyWriteMultimodalField() method to control whether it is written to XContent for implementing classes - Move tests for BaseJinaAIEmbeddingsServiceSettings.fromMap() to the test classes for JinaAITextEmbeddingServiceSettings and JinaAIEmbeddingServiceSettings --- .../BaseJinaAIEmbeddingsServiceSettings.java | 88 ++++---- .../JinaAIEmbeddingServiceSettings.java | 26 ++- .../embeddings/JinaAIEmbeddingsModel.java | 15 +- .../JinaAITextEmbeddingServiceSettings.java | 35 ++- .../services/jinaai/JinaAIServiceTests.java | 2 +- ...eJinaAIEmbeddingsServiceSettingsTests.java | 204 ----------------- .../JinaAIEmbeddingServiceSettingsTests.java | 110 ++++++---- .../JinaAIEmbeddingsModelTests.java | 15 +- ...naAITextEmbeddingServiceSettingsTests.java | 207 ++++++------------ .../JinaAIEmbeddingsRequestEntityTests.java | 2 +- .../request/JinaAIEmbeddingsRequestTests.java | 4 +- 11 files changed, 247 insertions(+), 461 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java index 4cd742831afb8..833c58b3118e4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; @@ -27,13 +26,13 @@ import java.util.EnumSet; import java.util.Map; import java.util.Objects; +import java.util.function.BiFunction; import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; @@ -48,19 +47,36 @@ 
public abstract class BaseJinaAIEmbeddingsServiceSettings extends FilteredXConte "jina_ai_embedding_dimensions_support_added" ); - static BaseJinaAIEmbeddingsServiceSettings fromMap(Map map, TaskType taskType, ConfigurationParseContext context) { - Objects.requireNonNull(taskType); + @FunctionalInterface + public interface ConstructorInvoker { + T construct( + JinaAIServiceSettings commonSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + boolean dimensionsSetByUser, + boolean multimodalModel + ); + } + + static T fromMap( + Map map, + ConfigurationParseContext context, + BiFunction, ValidationException, Boolean> handleMultimodalModelField, + ConstructorInvoker constructor + ) { ValidationException validationException = new ValidationException(); var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class, validationException); JinaAIEmbeddingType embeddingType = parseEmbeddingType(map, validationException); Boolean dimensionsSetByUser; if (context == ConfigurationParseContext.PERSISTENT) { - dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); + dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class, validationException); if (dimensionsSetByUser == null) { dimensionsSetByUser = Boolean.FALSE; } @@ -68,39 +84,21 @@ static BaseJinaAIEmbeddingsServiceSettings fromMap(Map map, Task dimensionsSetByUser = dimensions != null; } - Boolean multimodalModel = null; - // Do not remove the MULTIMODAL_MODEL 
field from the map for TEXT_EMBEDDING since it's not supported - if (taskType == TaskType.EMBEDDING) { - multimodalModel = removeAsType(map, MULTIMODAL_MODEL, Boolean.class); - if (multimodalModel == null) { - multimodalModel = true; - } - } + boolean multimodalModel = handleMultimodalModelField.apply(map, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - if (taskType == TaskType.EMBEDDING) { - return new JinaAIEmbeddingServiceSettings( - commonServiceSettings, - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser, - multimodalModel - ); - } else { - return new JinaAITextEmbeddingServiceSettings( - commonServiceSettings, - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser - ); - } + return constructor.construct( + commonServiceSettings, + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); } static JinaAIEmbeddingType parseEmbeddingType(Map map, ValidationException validationException) { @@ -134,7 +132,7 @@ public static BaseJinaAIEmbeddingsServiceSettings updateEmbeddingDetails( private final Integer maxInputTokens; private final JinaAIEmbeddingType embeddingType; private final boolean dimensionsSetByUser; - private final Boolean multimodalModel; + private final boolean multimodalModel; public BaseJinaAIEmbeddingsServiceSettings( JinaAIServiceSettings commonSettings, @@ -143,7 +141,7 @@ public BaseJinaAIEmbeddingsServiceSettings( @Nullable Integer maxInputTokens, @Nullable JinaAIEmbeddingType embeddingType, boolean dimensionsSetByUser, - @Nullable Boolean multimodalModel + boolean multimodalModel ) { this.commonSettings = commonSettings; this.similarity = similarity; @@ -172,18 +170,12 @@ public BaseJinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { } if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { - this.multimodalModel = 
in.readOptionalBoolean(); + this.multimodalModel = in.readBoolean(); } else { - this.multimodalModel = null; + this.multimodalModel = false; } } - /** - * Returns whether this {@link BaseJinaAIEmbeddingsServiceSettings} defaults to supporting multimodal inputs or not - * @return {@code true} if these settings default to supporting multimodal inputs - */ - public abstract boolean getDefaultMultimodal(); - /** * Returns a new {@link BaseJinaAIEmbeddingsServiceSettings} with updated similarity and dimensions but all other fields unchanged * @param similarity the new similarity @@ -192,6 +184,8 @@ public BaseJinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { */ public abstract BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions); + protected abstract void optionallyWriteMultimodalField(XContentBuilder builder) throws IOException; + public JinaAIServiceSettings getCommonSettings() { return commonSettings; } @@ -231,7 +225,7 @@ public DenseVectorFieldMapper.ElementType elementType() { @Override public boolean isMultimodal() { - return multimodalModel != null ? 
multimodalModel : getDefaultMultimodal(); + return multimodalModel; } @Override @@ -263,9 +257,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil builder.field(SIMILARITY, similarity); } - if (multimodalModel != null) { - builder.field(MULTIMODAL_MODEL, multimodalModel); - } + optionallyWriteMultimodalField(builder); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java index 2d3e97bd09ae3..43be6b00fd591 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java @@ -10,18 +10,28 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; import java.io.IOException; import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; public class JinaAIEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings { public static final String NAME = "jinaai_multimodal_embedding_service_settings"; + public static final boolean DEFAULT_MULTIMODAL_MODEL = true; public static JinaAIEmbeddingServiceSettings fromMap(Map map, ConfigurationParseContext context) { - 
return (JinaAIEmbeddingServiceSettings) BaseJinaAIEmbeddingsServiceSettings.fromMap(map, TaskType.EMBEDDING, context); + return BaseJinaAIEmbeddingsServiceSettings.fromMap( + map, + context, + (m, v) -> Objects.requireNonNullElse(removeAsType(m, MULTIMODAL_MODEL, Boolean.class, v), DEFAULT_MULTIMODAL_MODEL), + JinaAIEmbeddingServiceSettings::new + ); } public JinaAIEmbeddingServiceSettings( @@ -31,7 +41,7 @@ public JinaAIEmbeddingServiceSettings( @Nullable Integer maxInputTokens, @Nullable JinaAIEmbeddingType embeddingType, boolean dimensionsSetByUser, - @Nullable Boolean multimodalModel + boolean multimodalModel ) { super(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser, multimodalModel); } @@ -40,11 +50,6 @@ public JinaAIEmbeddingServiceSettings(StreamInput in) throws IOException { super(in); } - @Override - public boolean getDefaultMultimodal() { - return true; - } - @Override public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions) { return new JinaAIEmbeddingServiceSettings( @@ -58,6 +63,11 @@ public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, ); } + @Override + protected void optionallyWriteMultimodalField(XContentBuilder builder) throws IOException { + builder.field(MULTIMODAL_MODEL, isMultimodal()); + } + @Override public String getWriteableName() { return NAME; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java index 4c6a4751e60ab..fb99e5077050c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java @@ -52,7 +52,7 @@ 
public JinaAIEmbeddingsModel( ) { this( inferenceId, - BaseJinaAIEmbeddingsServiceSettings.fromMap(serviceSettings, taskType, context), + createServiceSettings(serviceSettings, taskType, context), JinaAIEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, DefaultSecretSettings.fromMap(secrets), @@ -107,4 +107,17 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(JinaAIActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } + + private static BaseJinaAIEmbeddingsServiceSettings createServiceSettings( + Map serviceSettings, + TaskType taskType, + ConfigurationParseContext context + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> JinaAITextEmbeddingServiceSettings.fromMap(serviceSettings, context); + case EMBEDDING -> JinaAIEmbeddingServiceSettings.fromMap(serviceSettings, context); + // Should not be possible + default -> throw new IllegalArgumentException(); + }; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java index 9ec949be0bc29..88d3e2e350d91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import 
org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; @@ -20,11 +20,30 @@ public class JinaAITextEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings { /** * This name is a holdover from before the introduction of {@link JinaAIEmbeddingServiceSettings} to support multimodal embeddings + * This name cannot be changed due to backwards compatibility, but it should be 'jinaai_text_embedding_service_settings' */ public static final String NAME = "jinaai_embeddings_service_settings"; + public static final boolean DEFAULT_MULTIMODAL_MODEL = false; public static JinaAITextEmbeddingServiceSettings fromMap(Map map, ConfigurationParseContext context) { - return (JinaAITextEmbeddingServiceSettings) BaseJinaAIEmbeddingsServiceSettings.fromMap(map, TaskType.TEXT_EMBEDDING, context); + return BaseJinaAIEmbeddingsServiceSettings.fromMap( + map, + context, + (m, v) -> DEFAULT_MULTIMODAL_MODEL, + JinaAITextEmbeddingServiceSettings::new + ); + } + + private JinaAITextEmbeddingServiceSettings( + JinaAIServiceSettings commonServiceSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dims, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingTypes, + boolean dimensionsSetByUser, + boolean multimodalModel + ) { + super(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes, dimensionsSetByUser, DEFAULT_MULTIMODAL_MODEL); } public JinaAITextEmbeddingServiceSettings( @@ -35,18 +54,13 @@ public JinaAITextEmbeddingServiceSettings( @Nullable JinaAIEmbeddingType embeddingTypes, boolean dimensionsSetByUser ) { - super(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes, dimensionsSetByUser, null); + this(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes, dimensionsSetByUser, DEFAULT_MULTIMODAL_MODEL); } public JinaAITextEmbeddingServiceSettings(StreamInput in) throws IOException { super(in); } - @Override - public boolean getDefaultMultimodal() { - return 
false; - } - @Override public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions) { return new JinaAITextEmbeddingServiceSettings( @@ -59,6 +73,11 @@ public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, ); } + @Override + protected void optionallyWriteMultimodalField(XContentBuilder builder) { + // Do not include the multimodal_model field for text_embedding, because it is always false + } + @Override public String getWriteableName() { return NAME; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 424553d0a39e8..5554a6b2eca94 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -1074,7 +1074,7 @@ private void testInfer_TextEmbedding_Get_Response(InputType inputType, String ex apiKey, dimensions, TEXT_EMBEDDING, - null + false ); PlainActionFuture listener = new PlainActionFuture<>(); List input = List.of("abc"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java index 1d5a096af2ed7..c67d4c158fab5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java @@ -7,227 +7,23 @@ package 
org.elasticsearch.xpack.inference.services.jinaai.embeddings; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; -import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; -import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; -import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.PERSISTENT; -import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.REQUEST; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.updateEmbeddingDetails; -import static org.hamcrest.Matchers.anEmptyMap; -import static org.hamcrest.Matchers.is; -import static 
org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.sameInstance; public class BaseJinaAIEmbeddingsServiceSettingsTests extends ESTestCase { - public void testFromMap_parsesAllFields_textEmbedding_requestContext() { - testFromMap_parsesAllFields(TEXT_EMBEDDING, REQUEST, randomNonNegativeInt()); - } - - public void testFromMap_parsesAllFields_embedding_requestContext() { - testFromMap_parsesAllFields(TaskType.EMBEDDING, REQUEST, randomNonNegativeInt()); - } - - public void testFromMap_parsesAllFields_textEmbedding_persistentContext() { - testFromMap_parsesAllFields(TEXT_EMBEDDING, PERSISTENT, randomNonNegativeInt()); - } - - public void testFromMap_parsesAllFields_embedding_persistentContext() { - testFromMap_parsesAllFields(TaskType.EMBEDDING, PERSISTENT, randomNonNegativeInt()); - } - - public void testFromMap_parsesAllFields_textEmbedding_requestContext_dimensionsNotSet() { - testFromMap_parsesAllFields(TEXT_EMBEDDING, REQUEST, null); - } - - public void testFromMap_parsesAllFields_embedding_requestContext_dimensionsNotSet() { - testFromMap_parsesAllFields(TaskType.EMBEDDING, REQUEST, null); - } - - private void testFromMap_parsesAllFields(TaskType taskType, ConfigurationParseContext parseContext, Integer dimensions) { - var similarity = randomSimilarityMeasure(); - var maxInputTokens = randomNonNegativeInt(); - var model = randomAlphanumericOfLength(8); - var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - var requestsPerMinute = randomNonNegativeInt(); - var settingsMap = getMapOfCommonEmbeddingSettings( - model, - similarity, - dimensions, - null, - maxInputTokens, - embeddingType, - requestsPerMinute - ); - - var dimensionsSetByUser = dimensions != null; - if (parseContext == PERSISTENT) { - dimensionsSetByUser = randomBoolean(); - settingsMap.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); - } - - var multimodalModel = false; - if (taskType == TaskType.EMBEDDING) { - multimodalModel = randomBoolean(); - 
settingsMap.put(MULTIMODAL_MODEL, multimodalModel); - } - - var serviceSettings = BaseJinaAIEmbeddingsServiceSettings.fromMap(settingsMap, taskType, parseContext); - - assertThat(settingsMap, anEmptyMap()); - - assertServiceSettings( - serviceSettings, - taskType, - model, - requestsPerMinute, - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser, - multimodalModel - ); - } - - public void testFromMap_doesNotRemoveMultimodalModelField_whenTaskTypeIsTextEmbedding() { - var settingsMap = getMapOfMinimalEmbeddingSettings(randomAlphanumericOfLength(8)); - - settingsMap.put(MULTIMODAL_MODEL, randomBoolean()); - - var settings = BaseJinaAIEmbeddingsServiceSettings.fromMap( - settingsMap, - TEXT_EMBEDDING, - randomFrom(ConfigurationParseContext.values()) - ); - - assertThat(settingsMap.get(MULTIMODAL_MODEL), notNullValue()); - assertThat(settings.isMultimodal(), is(false)); - } - - private static void assertServiceSettings( - BaseJinaAIEmbeddingsServiceSettings serviceSettings, - TaskType taskType, - String model, - Integer requestsPerMinute, - SimilarityMeasure similarity, - Integer dimensions, - Integer maxInputTokens, - JinaAIEmbeddingType embeddingType, - boolean dimensionsSetByUser, - Boolean multimodalModel - ) { - BaseJinaAIEmbeddingsServiceSettings expectedSettings; - if (taskType == TEXT_EMBEDDING) { - expectedSettings = new JinaAITextEmbeddingServiceSettings( - new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser - ); - } else if (taskType == TaskType.EMBEDDING) { - expectedSettings = new JinaAIEmbeddingServiceSettings( - new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser, - multimodalModel - ); - } else { - throw new IllegalArgumentException("Invalid taskType " + taskType); - } - - assertThat(serviceSettings, 
is(expectedSettings)); - } - - public void testFromMap_withInvalidSimilarity_throwsError() { - var similarity = "by_size"; - var thrownException = expectThrows( - ValidationException.class, - () -> BaseJinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", SIMILARITY, similarity)), - randomEmbeddingTaskType(), - randomFrom(ConfigurationParseContext.values()) - ) - ); - - assertThat( - thrownException.getMessage(), - is( - "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] " - + "must be one of [cosine, dot_product, l2_norm];" - ) - ); - } - - public void testFromMap_nonPositiveDimensions_throwsError() { - var dimensions = randomIntBetween(-5, 0); - var thrownException = expectThrows( - ValidationException.class, - () -> BaseJinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", DIMENSIONS, dimensions)), - randomEmbeddingTaskType(), - randomFrom(ConfigurationParseContext.values()) - ) - ); - - assertThat( - thrownException.getMessage(), - is( - Strings.format( - "Validation Failed: 1: [service_settings] Invalid value [%d]. [%s] must be a positive integer;", - dimensions, - DIMENSIONS - ) - ) - ); - } - - public void testFromMap_withInvalidEmbeddingType_throwsError() { - var embeddingType = "bad_value"; - var thrownException = expectThrows( - ValidationException.class, - () -> BaseJinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), - randomEmbeddingTaskType(), - randomFrom(ConfigurationParseContext.values()) - ) - ); - - assertThat( - thrownException.getMessage(), - is( - "Validation Failed: 1: [service_settings] Invalid value [bad_value] received. [embedding_type] " - + "must be one of [binary, bit, float];" - ) - ); - } - public void testUpdateEmbeddingDetails_returnsSameInstance_whenEmbeddingSizeAndSimilarityAreSame() { var settings = randomBoolean() ? 
JinaAIEmbeddingServiceSettingsTests.createRandomWithNoNullValues() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java index d152fb3d746cc..e2bea367aa8f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -37,13 +36,16 @@ import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.PERSISTENT; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.REQUEST; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; import static 
org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; -import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; import static org.hamcrest.Matchers.is; public class JinaAIEmbeddingServiceSettingsTests extends AbstractBWCWireSerializationTestCase { @@ -90,36 +92,47 @@ public static JinaAIEmbeddingServiceSettings createRandomWithNoNullValues() { ); } - public void testFromMap_Request_CreatesSettingsCorrectly() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; + public void testFromMap_persistentContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), PERSISTENT); + } + + public void testFromMap_requestContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), REQUEST); + } + + public void testFromMap_requestContext_nullDimensions_createsSettingsCorrectly() { + testFromMap(null, REQUEST); + } + + private static void testFromMap(Integer dimensions, ConfigurationParseContext parseContext) { + var similarity = randomSimilarityMeasure(); + var maxInputTokens = randomNonNegativeInt(); + var model = randomAlphanumericOfLength(8); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - var requestsPerMinute = 1234; - var multiModalModel = false; - var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - ServiceFields.DIMENSIONS, - dimensions, - ServiceFields.MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - 
EMBEDDING_TYPE, - embeddingType.toString(), - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)), - MULTIMODAL_MODEL, - multiModalModel - ) - ), - ConfigurationParseContext.REQUEST + var requestsPerMinute = randomNonNegativeInt(); + var multimodalModel = randomBoolean(); + var settingsMap = getMapOfCommonEmbeddingSettings( + model, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute ); + settingsMap.put(MULTIMODAL_MODEL, multimodalModel); + + boolean dimensionsSetByUser; + if (parseContext == REQUEST) { + dimensionsSetByUser = dimensions != null; + } else { + dimensionsSetByUser = randomBoolean(); + settingsMap.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + + var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap(settingsMap, parseContext); + assertThat( serviceSettings, is( @@ -129,8 +142,8 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { dimensions, maxInputTokens, embeddingType, - true, - multiModalModel + dimensionsSetByUser, + multimodalModel ) ) ); @@ -140,7 +153,6 @@ public void testFromMap_onlyRequiredFields() { var model = "model"; var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(MODEL_ID, model)), - TaskType.EMBEDDING, randomFrom(ConfigurationParseContext.values()) ); @@ -166,8 +178,7 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { ValidationException.class, () -> JinaAIEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)), - TaskType.EMBEDDING, - ConfigurationParseContext.PERSISTENT + randomFrom(ConfigurationParseContext.values()) ) ); @@ -183,14 +194,35 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { ); } + public void testFromMap_nonPositiveDimensions_throwsError() { + var dimensions = randomIntBetween(-5, 0); + var thrownException = expectThrows( + ValidationException.class, + () -> 
JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", DIMENSIONS, dimensions)), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [%s] must be a positive integer;", + dimensions, + DIMENSIONS + ) + ) + ); + } + public void testFromMap_InvalidEmbeddingType_ThrowsError() { var embeddingType = "invalid"; var thrownException = expectThrows( ValidationException.class, () -> JinaAIEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), - TaskType.EMBEDDING, - ConfigurationParseContext.PERSISTENT + randomFrom(ConfigurationParseContext.values()) ) ); @@ -296,13 +328,13 @@ protected JinaAIEmbeddingServiceSettings mutateInstance(JinaAIEmbeddingServiceSe @Override protected JinaAIEmbeddingServiceSettings mutateInstanceForVersion(JinaAIEmbeddingServiceSettings instance, TransportVersion version) { - Boolean multimodalModel = instance.isMultimodal(); + boolean multimodalModel = instance.isMultimodal(); boolean dimensionsSetByUser = instance.dimensionsSetByUser(); JinaAIEmbeddingType embeddingType = instance.getEmbeddingType(); - // default to null for multimodal if node is on a version before embedding task support was added + // default to false for multimodal if node is on a version before embedding task support was added if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED) == false) { - multimodalModel = null; + multimodalModel = false; } // default to false for dimensionsSetByUser if node is on a version before support for setting embedding dimensions was added if (version.supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED) == false) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java index 7232a576cb508..9b8d5a0cfa784 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java @@ -20,6 +20,7 @@ import java.util.Map; +import static org.elasticsearch.inference.TaskType.EMBEDDING; import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap; @@ -109,7 +110,7 @@ public static JinaAIEmbeddingsModel createTextEmbeddingModel(String url, String * Returns a model with empty task settings, service settings and chunking settings, using the {@link TaskType#EMBEDDING} task type */ public static JinaAIEmbeddingsModel createEmbeddingModel(String url, String modelName, String apiKey) { - return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, TaskType.EMBEDDING); + return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, EMBEDDING); } /** @@ -122,7 +123,7 @@ public static JinaAIEmbeddingsModel createModel( String apiKey, TaskType taskType ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, null); + var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, taskType == EMBEDDING); return createModel(url, serviceSettings, taskSettings, null, apiKey, taskType); } @@ -137,7 +138,7 @@ public static JinaAIEmbeddingsModel createModel( String apiKey, TaskType taskType ) { - var serviceSettings = 
getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, null); + var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, taskType == EMBEDDING); return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey, taskType); } @@ -152,7 +153,7 @@ public static JinaAIEmbeddingsModel createModel( String apiKey, @Nullable Integer dimensions, TaskType taskType, - @Nullable Boolean multimodalModel + boolean multimodalModel ) { var serviceSettings = getEmbeddingServiceSettings( modelName, @@ -181,7 +182,7 @@ public static JinaAIEmbeddingsModel createModel( String apiKey, boolean dimensionsSetByUser, TaskType taskType, - @Nullable Boolean multimodalModel + boolean multimodalModel ) { var serviceSettings = getEmbeddingServiceSettings( modelName, @@ -225,7 +226,7 @@ public static BaseJinaAIEmbeddingsServiceSettings getEmbeddingServiceSettings( @Nullable JinaAIEmbeddingType embeddingType, boolean dimensionsSetByUser, TaskType taskType, - @Nullable Boolean multimodalModel + boolean multimodalModel ) { if (taskType == TaskType.TEXT_EMBEDDING) { return new JinaAITextEmbeddingServiceSettings( @@ -236,7 +237,7 @@ public static BaseJinaAIEmbeddingsServiceSettings getEmbeddingServiceSettings( embeddingType, dimensionsSetByUser ); - } else if (taskType == TaskType.EMBEDDING) { + } else if (taskType == EMBEDDING) { return new JinaAIEmbeddingServiceSettings( new JinaAIServiceSettings(modelName, rateLimitSettings), similarity, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java index d3e1ae8ae4cb9..c1dec2ded9f3a 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java @@ -35,16 +35,20 @@ import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.PERSISTENT; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.REQUEST; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; -import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; -import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; +import static org.hamcrest.Matchers.anEmptyMap; import static 
org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; public class JinaAITextEmbeddingServiceSettingsTests extends AbstractBWCWireSerializationTestCase { @@ -86,104 +90,44 @@ public static JinaAITextEmbeddingServiceSettings createRandomWithNoNullValues() ); } - public void testFromMap_Request_CreatesSettingsCorrectly() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; - var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - var requestsPerMinute = 1234; - var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - ServiceFields.DIMENSIONS, - dimensions, - ServiceFields.MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - EMBEDDING_TYPE, - embeddingType.toString(), - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)) - ) - ), - ConfigurationParseContext.REQUEST - ); - - assertThat( - serviceSettings, - is( - new JinaAITextEmbeddingServiceSettings( - new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), - similarity, - dimensions, - maxInputTokens, - embeddingType, - true - ) - ) - ); + public void testFromMap_persistentContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), PERSISTENT); } - public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() { - var url = "https://www.abc.com"; - var model = "model"; - var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( - new HashMap<>(Map.of(URL, url, ServiceFields.MODEL_ID, model)), - ConfigurationParseContext.REQUEST - ); + public void testFromMap_requestContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), REQUEST); + } - assertThat( - serviceSettings, - is( - new JinaAITextEmbeddingServiceSettings( - new JinaAIServiceSettings(model, null), - null, - null, - null, 
- JinaAIEmbeddingType.FLOAT, - false - ) - ) - ); + public void testFromMap_requestContext_nullDimensions_createsSettingsCorrectly() { + testFromMap(null, REQUEST); } - public void testFromMap_Persistent_CreatesSettingsCorrectly() { - var url = "https://www.abc.com"; + private static void testFromMap(Integer dimensions, ConfigurationParseContext parseContext) { var similarity = randomSimilarityMeasure(); - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; + var maxInputTokens = randomNonNegativeInt(); + var model = randomAlphanumericOfLength(8); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - var requestsPerMinute = 1234; - var dimensionsSetByUser = randomBoolean(); - var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( - new HashMap<>( - Map.of( - URL, - url, - SIMILARITY, - similarity.toString(), - DIMENSIONS, - dimensions, - MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - EMBEDDING_TYPE, - embeddingType.toString(), - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)), - ServiceFields.DIMENSIONS_SET_BY_USER, - dimensionsSetByUser - ) - ), - ConfigurationParseContext.PERSISTENT + var requestsPerMinute = randomNonNegativeInt(); + var settingsMap = getMapOfCommonEmbeddingSettings( + model, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute ); + boolean dimensionsSetByUser; + if (parseContext == REQUEST) { + dimensionsSetByUser = dimensions != null; + } else { + dimensionsSetByUser = randomBoolean(); + settingsMap.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap(settingsMap, parseContext); + assertThat( serviceSettings, is( @@ -199,25 +143,11 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { ); } - public void testFromMap_WithModelId() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - 
var maxInputTokens = 512; + public void testFromMap_onlyRequiredFields() { var model = "model"; var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - DIMENSIONS, - dimensions, - MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model - ) - ), - ConfigurationParseContext.REQUEST + new HashMap<>(Map.of(MODEL_ID, model)), + randomFrom(ConfigurationParseContext.values()) ); assertThat( @@ -225,49 +155,33 @@ public void testFromMap_WithModelId() { is( new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, null), - similarity, - dimensions, - maxInputTokens, + null, + null, + null, JinaAIEmbeddingType.FLOAT, - true + false ) ) ); } - public void testFromMap_WithEmbeddingType() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; - var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - DIMENSIONS, - dimensions, - MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - EMBEDDING_TYPE, - JinaAIEmbeddingType.BIT.toString() - ) - ), - ConfigurationParseContext.REQUEST + public void testFromMap_InvalidEmbeddingType_ThrowsError() { + var embeddingType = "invalid"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAITextEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), + randomFrom(ConfigurationParseContext.values()) + ) ); assertThat( - serviceSettings, + thrownException.getMessage(), is( - new JinaAITextEmbeddingServiceSettings( - new JinaAIServiceSettings(model, null), - similarity, - dimensions, - maxInputTokens, - JinaAIEmbeddingType.BIT, - true + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%s] received. 
[embedding_type] " + + "must be one of [binary, bit, float];", + embeddingType ) ) ); @@ -314,6 +228,15 @@ public void testFromMap_nonPositiveDimensions_ThrowsError() { ); } + public void testFromMap_doesNotRemoveMultimodalModelField() { + var model = "model"; + HashMap settingsMap = new HashMap<>(Map.of(MODEL_ID, model, MULTIMODAL_MODEL, true)); + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap(settingsMap, randomFrom(ConfigurationParseContext.values())); + + assertThat(serviceSettings.isMultimodal(), is(false)); + assertThat(settingsMap, not(anEmptyMap())); + } + public void testToXContent_WritesAllValues() throws IOException { var modelName = randomAlphanumericOfLength(10); var requestsPerMinute = randomNonNegativeInt(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java index 317f52d970200..3847f2b5b09e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java @@ -331,7 +331,7 @@ public void testXContent_multimodal_WritesNoOptionalFields_WhenTheyAreNotDefined var entity = new JinaAIEmbeddingsRequestEntity( List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(new InferenceString(IMAGE, imageInput))), null, - createModel(null, modelName, null, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null, TaskType.EMBEDDING, null) + createModel(null, modelName, null, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null, TaskType.EMBEDDING, true) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java index 5567ce83a9af9..bff4ea69eb3bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java @@ -53,7 +53,7 @@ public void testCreateRequest_AllOptionsDefined_textEmbedding() throws IOExcepti apiKey, dimensions, TaskType.TEXT_EMBEDDING, - null + false ) ); @@ -164,7 +164,7 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { apiKey, dimensions, TaskType.TEXT_EMBEDDING, - null + false ) ); From b0e3e519efb6819712642ea340d4ed5e34bbd424 Mon Sep 17 00:00:00 2001 From: donalevans Date: Fri, 9 Jan 2026 15:50:15 -0800 Subject: [PATCH 4/5] Validate content objects do not contain multiple items --- .../inference/EmbeddingRequest.java | 2 +- .../inference/InferenceStringGroup.java | 23 ++++++- .../inference/InferenceStringGroupTests.java | 62 ++++++++++++++++++- .../inference/services/SenderService.java | 37 ++++++++++- .../JinaAIEmbeddingsRequestEntity.java | 15 +++-- .../services/jinaai/JinaAIServiceTests.java | 46 ++++++++++++++ 6 files changed, 169 insertions(+), 16 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java index 66722d85ec06e..637576ef29870 100644 --- a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java @@ -77,7 +77,7 @@ public record EmbeddingRequest(List inputs, InputType inpu public static final TransportVersion JINA_AI_EMBEDDING_TASK_ADDED = 
TransportVersion.fromName("jina_ai_embedding_task_added"); - private static final String INPUT_FIELD = "input"; + public static final String INPUT_FIELD = "input"; private static final String INPUT_TYPE_FIELD = "input_type"; @SuppressWarnings("unchecked") diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java index 134c8082259cd..9db304e58a987 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java @@ -40,7 +40,7 @@ * */ public final class InferenceStringGroup implements Writeable, ToXContentObject { - private static final String CONTENT_FIELD = "content"; + public static final String CONTENT_FIELD = "content"; @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -84,6 +84,10 @@ public boolean containsNonTextEntry() { return containsNonTextEntry; } + public boolean containsMultipleInferenceStrings() { + return inferenceStrings.size() > 1; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(inferenceStrings); @@ -166,6 +170,23 @@ public static boolean containsNonTextEntry(List inferenceS return inferenceStringGroups.stream().anyMatch(InferenceStringGroup::containsNonTextEntry); } + /** + * Method used to determine if a list of {@link InferenceStringGroup} contains any with more than one {@link InferenceString} in them + * + * @param inferenceStringGroups the list of {@link InferenceStringGroup} to check + * @return the index of the first {@link InferenceStringGroup} found to contain more than one {@link InferenceString}, or null if no + * elements in the list contain more than one {@link InferenceString} + */ + public static Integer indexContainingMultipleInferenceStrings(List inferenceStringGroups) { + for (int i = 0; i < 
inferenceStringGroups.size(); i++) { + InferenceStringGroup inferenceStringGroup = inferenceStringGroups.get(i); + if (inferenceStringGroup.containsMultipleInferenceStrings()) { + return i; + } + } + return null; + } + @Override public boolean equals(Object obj) { if (obj == this) return true; diff --git a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java index 8ec4de041f3c6..b68af794921c9 100644 --- a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java +++ b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java @@ -14,16 +14,21 @@ import org.elasticsearch.inference.InferenceString.DataFormat; import org.elasticsearch.inference.InferenceString.DataType; import org.elasticsearch.test.AbstractBWCSerializationTestCase; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; import static org.elasticsearch.inference.InferenceStringGroup.containsNonTextEntry; +import static org.elasticsearch.inference.InferenceStringGroup.indexContainingMultipleInferenceStrings; import static org.elasticsearch.inference.InferenceStringGroup.toInferenceStringList; import static org.elasticsearch.inference.InferenceStringGroup.toStringList; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; public class InferenceStringGroupTests extends AbstractBWCSerializationTestCase { @@ -31,15 +36,40 @@ public class InferenceStringGroupTests extends AbstractBWCSerializationTestCase< public void testStringConstructor() { String stringValue = "a string"; var input = new 
InferenceStringGroup(stringValue); - assertThat(input.inferenceStrings(), contains(new InferenceString(DataType.TEXT, DataFormat.TEXT, stringValue))); + assertThat(input.inferenceStrings(), contains(new InferenceString(TEXT, DataFormat.TEXT, stringValue))); assertThat(input.containsNonTextEntry(), is(false)); + assertThat(input.containsMultipleInferenceStrings(), is(false)); } - public void testSingleArgumentConstructor() { + public void testSingleInferenceStringConstructor() { InferenceString inferenceString = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string"); var input = new InferenceStringGroup(inferenceString); assertThat(input.inferenceStrings(), contains(inferenceString)); assertThat(input.containsNonTextEntry(), is(true)); + assertThat(input.containsMultipleInferenceStrings(), is(false)); + } + + public void testInferenceStringListConstructor() { + InferenceString inferenceString1 = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string"); + InferenceString inferenceString2 = new InferenceString(TEXT, DataFormat.TEXT, "a string"); + var input = new InferenceStringGroup(List.of(inferenceString1, inferenceString2)); + assertThat(input.inferenceStrings(), contains(inferenceString1, inferenceString2)); + assertThat(input.containsNonTextEntry(), is(true)); + assertThat(input.containsMultipleInferenceStrings(), is(true)); + } + + public void testParser_withEmptyContentObject_throws() throws IOException { + var requestJson = """ + { + "content": {} + } + """; + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + // Need to call nextToken() so that the parser is at the correct element + parser.nextToken(); + var exception = expectThrows(XContentParseException.class, () -> InferenceStringGroup.PARSER.apply(parser, null)); + assertThat(exception.getMessage(), containsString("[InferenceStringGroup] failed to parse field [content]")); + } } public void testValue_withMoreThanOneElement_throws() { @@ -86,7 +116,7 @@ 
public void testContainsNonTextEntry_withOnlyTextInputs() { } public void testContainsNonTextEntry_withNonTextInput() { - DataType nonTextDataType = randomValueOtherThan(DataType.TEXT, () -> randomFrom(DataType.values())); + DataType nonTextDataType = randomValueOtherThan(TEXT, () -> randomFrom(DataType.values())); var inputs = List.of( new InferenceStringGroup("string1"), new InferenceStringGroup(new InferenceString(nonTextDataType, "non text")) @@ -94,6 +124,32 @@ public void testContainsNonTextEntry_withNonTextInput() { assertThat(containsNonTextEntry(inputs), is(true)); } + public void testIndexContainingMultipleInferenceStrings_withSingleInferenceString() { + var inputs = getInputsList(); + assertThat(indexContainingMultipleInferenceStrings(inputs), is(null)); + } + + public void testIndexContainingMultipleInferenceStrings_withMultipleInferenceStrings() { + var inputs = getInputsList(); + + // Add an InferenceStringGroup with multiple InferenceStrings at a random point in the input list + var indexToAdd = randomIntBetween(0, inputs.size() - 1); + var multipleInferenceStrings = new InferenceStringGroup( + List.of(new InferenceString(TEXT, "a_string"), new InferenceString(TEXT, "a_string")) + ); + inputs.add(indexToAdd, multipleInferenceStrings); + assertThat(indexContainingMultipleInferenceStrings(inputs), is(indexToAdd)); + } + + private static ArrayList getInputsList() { + var listSize = randomIntBetween(1, 10); + var inputs = new ArrayList(listSize); + for (int i = 0; i < listSize; ++i) { + inputs.add(new InferenceStringGroup("a_string")); + } + return inputs; + } + @Override protected InferenceStringGroup mutateInstanceForVersion(InferenceStringGroup instance, TransportVersion version) { return instance; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 47b394eed3c39..60a4ee2cf48a6 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -21,6 +21,7 @@ import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; @@ -41,6 +42,7 @@ import java.util.Objects; import java.util.Set; +import static org.elasticsearch.inference.InferenceStringGroup.indexContainingMultipleInferenceStrings; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedEmbeddingOperation; public abstract class SenderService implements InferenceService { @@ -138,9 +140,38 @@ public void unifiedCompletionInfer( @Override public void embeddingInfer(Model model, EmbeddingRequest request, TimeValue timeout, ActionListener listener) { - SubscribableListener.newForked(this::init) - .andThen((embeddingInferListener) -> doEmbeddingInfer(model, request, timeout, embeddingInferListener)) - .addListener(listener); + SubscribableListener.newForked(this::init).andThen((embeddingInferListener) -> { + if (supportsMultipleItemsPerContent()) { + doEmbeddingInfer(model, request, timeout, embeddingInferListener); + } else { + var index = indexContainingMultipleInferenceStrings(request.inputs()); + if (index == null) { + doEmbeddingInfer(model, request, timeout, embeddingInferListener); + } else { + listener.onFailure( + new ElasticsearchStatusException( + Strings.format( + "Field [%1$s] must contain a single item for [%2$s] service. 
" + + "[%1$s] object with multiple items found at $.%3$s.%1$s[%4$d]", + InferenceStringGroup.CONTENT_FIELD, + name(), + EmbeddingRequest.INPUT_FIELD, + index + ), + RestStatus.BAD_REQUEST + ) + ); + } + } + }).addListener(listener); + } + + /** + * Override as necessary for services which support generating a single embedding vector from multiple inputs + * @return true if the service supports sending embedding requests where multiple inputs are used to generate a single embedding vector + */ + protected boolean supportsMultipleItemsPerContent() { + return false; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java index fdb0997bf129b..a5390d58e8b4b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java @@ -85,15 +85,14 @@ private void writeInputs(XContentBuilder builder) throws IOException { if (model.getServiceSettings().isMultimodal()) { builder.startArray(INPUT_FIELD); for (var inferenceStringGroup : input) { - for (var inferenceString : inferenceStringGroup.inferenceStrings()) { - builder.startObject(); - if (inferenceString.isText()) { - builder.field(INPUT_TEXT_FIELD, inferenceString.value()); - } else if (inferenceString.isImage()) { - builder.field(INPUT_IMAGE_FIELD, inferenceString.value()); - } - builder.endObject(); + var inferenceString = inferenceStringGroup.value(); + builder.startObject(); + if (inferenceString.isText()) { + builder.field(INPUT_TEXT_FIELD, inferenceString.value()); + } else if (inferenceString.isImage()) { + builder.field(INPUT_IMAGE_FIELD, inferenceString.value()); } + 
builder.endObject(); } builder.endArray(); } else { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 5554a6b2eca94..8095d4ccb0f48 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -72,6 +72,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,6 +80,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.inference.InferenceString.DataFormat.BASE64; import static org.elasticsearch.inference.InferenceString.DataType.IMAGE; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS; import static org.elasticsearch.inference.TaskType.EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; @@ -1786,6 +1788,50 @@ public void testEmbeddingInfer_returnsError_nonMultimodalModel_withNonTextInput( } } + public void testEmbeddingInfer_returnsError_multipleItemsInContentObject() throws IOException { + var model = JinaAIEmbeddingsModelTests.createEmbeddingModel(getUrl(webServer), "modelName", "apiKey"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + + var listSize = randomIntBetween(1, 10); + var inputs = new ArrayList(listSize); + for (int i = 0; i < listSize; ++i) { + inputs.add(new 
InferenceStringGroup("a_string")); + } + + // Add an InferenceStringGroup with multiple InferenceStrings at a random point in the input list + var indexToAdd = randomIntBetween(0, inputs.size() - 1); + var multipleInferenceStrings = new InferenceStringGroup( + List.of( + new InferenceString(TEXT, InferenceString.DataFormat.TEXT, "first_input"), + new InferenceString(IMAGE, BASE64, "second_input") + ) + ); + inputs.add(indexToAdd, multipleInferenceStrings); + service.embeddingInfer( + model, + new EmbeddingRequest(inputs, InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Field [content] must contain a single item for [jinaai] service. " + + "[content] object with multiple items found at $.input.content[%d]", + indexToAdd + ) + ) + ); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + public void testEmbeddingInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); From 11ea8583d034c0b2f319def1f9675d22bba7c507 Mon Sep 17 00:00:00 2001 From: donalevans Date: Mon, 12 Jan 2026 06:48:19 -0800 Subject: [PATCH 5/5] Fix assertion --- .../org/elasticsearch/inference/InferenceStringGroupTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java index b68af794921c9..c873a21d936a1 100644 --- a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java +++ b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java @@ -30,6 +30,7 @@ import static org.hamcrest.Matchers.contains; import static 
org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; public class InferenceStringGroupTests extends AbstractBWCSerializationTestCase { @@ -126,7 +127,7 @@ public void testContainsNonTextEntry_withNonTextInput() { public void testIndexContainingMultipleInferenceStrings_withSingleInferenceString() { var inputs = getInputsList(); - assertThat(indexContainingMultipleInferenceStrings(inputs), is(null)); + assertThat(indexContainingMultipleInferenceStrings(inputs), nullValue()); } public void testIndexContainingMultipleInferenceStrings_withMultipleInferenceStrings() {