diff --git a/docs/changelog/140323.yaml b/docs/changelog/140323.yaml new file mode 100644 index 0000000000000..1ba19da787a22 --- /dev/null +++ b/docs/changelog/140323.yaml @@ -0,0 +1,5 @@ +pr: 140323 +summary: "[Inference API] Add support for embedding task to JinaAI service" +area: Inference +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java index 29bf2727562f5..637576ef29870 100644 --- a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java @@ -9,6 +9,7 @@ package org.elasticsearch.inference; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -26,9 +27,11 @@ import java.io.IOException; import java.util.List; +import java.util.Map; import java.util.Objects; import static java.util.Collections.singletonList; +import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -62,18 +65,25 @@ * OR *
  * "input": ["first text input", "second text input"]
- * @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for - * @param inputType The {@link InputType} of the request + * + * @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for + * @param inputType The {@link InputType} of the request + * @param taskSettings The map of task settings specific to this request */ -public record EmbeddingRequest(List inputs, InputType inputType) implements Writeable, ToXContentFragment { +public record EmbeddingRequest(List inputs, InputType inputType, Map taskSettings) + implements + Writeable, + ToXContentFragment { - private static final String INPUT_FIELD = "input"; + public static final TransportVersion JINA_AI_EMBEDDING_TASK_ADDED = TransportVersion.fromName("jina_ai_embedding_task_added"); + + public static final String INPUT_FIELD = "input"; private static final String INPUT_TYPE_FIELD = "input_type"; @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( EmbeddingRequest.class.getSimpleName(), - args -> new EmbeddingRequest((List) args[0], (InputType) args[1]) + args -> new EmbeddingRequest((List) args[0], (InputType) args[1], (Map) args[2]) ); static { @@ -89,31 +99,48 @@ public record EmbeddingRequest(List inputs, InputType inpu new ParseField(INPUT_TYPE_FIELD), ObjectParser.ValueType.STRING ); + PARSER.declareField( + optionalConstructorArg(), + (parser, context) -> parser.mapOrdered(), + new ParseField(TASK_SETTINGS), + ObjectParser.ValueType.OBJECT + ); } public static EmbeddingRequest of(List contents) { - return new EmbeddingRequest(contents, null); + return new EmbeddingRequest(contents, null, null); } - public EmbeddingRequest(List inputs, @Nullable InputType inputType) { + public EmbeddingRequest(List inputs, @Nullable InputType inputType, @Nullable Map taskSettings) { this.inputs = inputs; this.inputType = Objects.requireNonNullElse(inputType, InputType.UNSPECIFIED); + this.taskSettings = Objects.requireNonNullElse(taskSettings, Map.of()); } public EmbeddingRequest(StreamInput in) throws IOException { - this(in.readCollectionAsImmutableList(InferenceStringGroup::new), in.readEnum(InputType.class)); + this( + in.readCollectionAsImmutableList(InferenceStringGroup::new), + in.readEnum(InputType.class), + in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED) ? in.readGenericMap() : Map.of() + ); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(inputs); out.writeEnum(inputType); + if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + out.writeGenericMap(taskSettings); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(INPUT_FIELD, inputs); builder.field(INPUT_TYPE_FIELD, inputType); + if (taskSettings.isEmpty() == false) { + builder.field(TASK_SETTINGS, taskSettings); + } return builder; } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java index 986503d059120..9db304e58a987 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.List; +import java.util.Objects; import static java.util.Collections.singletonList; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -37,20 +38,31 @@ * ] * } * - * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector */ -public record InferenceStringGroup(List inferenceStrings) implements Writeable, ToXContentObject { - private static final String CONTENT_FIELD = "content"; +public final class InferenceStringGroup implements Writeable, ToXContentObject { + public static final String CONTENT_FIELD = "content"; @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( InferenceStringGroup.class.getSimpleName(), args -> new InferenceStringGroup((List) args[0]) ); + static { PARSER.declareObjectArray(constructorArg(), InferenceString.PARSER::apply, new ParseField(CONTENT_FIELD)); } + private final List inferenceStrings; + private final boolean containsNonTextEntry; + + /** + * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector + */ + public InferenceStringGroup(List inferenceStrings) { + this.inferenceStrings = inferenceStrings; + containsNonTextEntry = inferenceStrings.stream().anyMatch(s -> s.isText() == false); + } + public InferenceStringGroup(StreamInput in) throws IOException { this(in.readCollectionAsImmutableList(InferenceString::new)); } @@ -64,6 +76,18 @@ public InferenceStringGroup(String input) { this(singletonList(new InferenceString(DataType.TEXT, input))); } + public List inferenceStrings() { + return inferenceStrings; + } + + public boolean containsNonTextEntry() { + return containsNonTextEntry; + } + + public boolean containsMultipleInferenceStrings() { + return inferenceStrings.size() > 1; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(inferenceStrings); @@ -81,7 +105,7 @@ public static InferenceStringGroup parse(XContentParser parser) throws IOExcepti var token = parser.currentToken(); if (token == XContentParser.Token.VALUE_STRING) { // Create content object from String - return new InferenceStringGroup(singletonList(new InferenceString(DataType.TEXT, parser.text()))); + return new InferenceStringGroup(parser.text()); } else if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { // Create content object from InferenceString(s) return InferenceStringGroup.PARSER.apply(parser, null); @@ -134,4 +158,51 @@ public static List toInferenceStringList(List toStringList(List inferenceStringGroups) { return InferenceString.toStringList(toInferenceStringList(inferenceStringGroups)); } + + /** + * Method used to determine if a list of {@link InferenceStringGroup} contains any {@link InferenceString} that represent a non-text + * value + * + * @param inferenceStringGroups the list of {@link InferenceStringGroup} to check + * @return true if the input list contains any non-text values, false otherwise + */ + public static boolean containsNonTextEntry(List inferenceStringGroups) { + return inferenceStringGroups.stream().anyMatch(InferenceStringGroup::containsNonTextEntry); + } + + /** + * Method used to determine if a list of {@link InferenceStringGroup} contains any with more than one {@link InferenceString} in them + * + * @param inferenceStringGroups the list of {@link InferenceStringGroup} to check + * @return the index of the first {@link InferenceStringGroup} found to contain more than one {@link InferenceString}, or null if no + * elements in the list contain more than one {@link InferenceString} + */ + public static Integer indexContainingMultipleInferenceStrings(List inferenceStringGroups) { + for (int i = 0; i < inferenceStringGroups.size(); i++) { + InferenceStringGroup inferenceStringGroup = inferenceStringGroups.get(i); + if (inferenceStringGroup.containsMultipleInferenceStrings()) { + return i; + } + } + return null; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) return true; + if (obj == null || obj.getClass() != this.getClass()) return false; + var that = (InferenceStringGroup) obj; + return Objects.equals(this.inferenceStrings, that.inferenceStrings); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceStrings); + } + + @Override + public String toString() { + return "InferenceStringGroup[" + "inferenceStrings=" + inferenceStrings + ']'; + } + } diff --git a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java index f97239f8df8df..81348f65e4847 100644 --- a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java @@ -55,6 +55,10 @@ default DenseVectorFieldMapper.ElementType elementType() { return null; } + default boolean isMultimodal() { + return false; + } + /** * The model to use in the inference endpoint (e.g. text-embedding-ada-002). Sometimes the model is not defined in the service * settings. This can happen for external providers (e.g. hugging face, azure ai studio) where the provider requires that the model diff --git a/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv b/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv new file mode 100644 index 0000000000000..9f242502d33af --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv @@ -0,0 +1 @@ +9259000 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv index 77cff68200459..008c2f07ab390 100644 --- a/server/src/main/resources/transport/upper_bounds/9.4.csv +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -1 +1 @@ -esql_vsr_converters_used,9258000 +jina_ai_embedding_task_added,9259000 diff --git a/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java b/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java index cbf78186ee142..1b120d0d19b21 100644 --- a/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java +++ b/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java @@ -21,7 +21,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.is; public class EmbeddingRequestTests extends AbstractBWCSerializationTestCase { @@ -40,6 +43,7 @@ public void testParser_withSingleString() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -60,6 +64,7 @@ public void testParser_withSingleContentObject() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -78,6 +83,7 @@ public void testParser_withStringArray() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -106,6 +112,7 @@ public void testParser_withSingleContentObjectWithMultipleEntries() throws IOExc ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -140,6 +147,7 @@ public void testParser_withMultipleContentObjects() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -173,6 +181,7 @@ public void testParser_withUnspecifiedFormats_usesDefaults() throws IOException ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.SEARCH)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -189,6 +198,46 @@ public void testParser_withNoInputType() throws IOException { ); assertThat(request.inputs(), is(expectedInputs)); assertThat(request.inputType(), is(InputType.UNSPECIFIED)); + assertThat(request.taskSettings(), anEmptyMap()); + } + } + + public void testParser_withTaskSettings() throws IOException { + var requestJson = """ + { + "input": "some text input", + "task_settings": { + "field_one": "value_one", + "field_two": 123 + } + } + """; + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = EmbeddingRequest.PARSER.apply(parser, null); + var expectedInputs = List.of( + new InferenceStringGroup(List.of(new InferenceString(DataType.TEXT, DataFormat.TEXT, "some text input"))) + ); + assertThat(request.inputs(), is(expectedInputs)); + assertThat(request.inputType(), is(InputType.UNSPECIFIED)); + assertThat(request.taskSettings(), is(Map.of("field_one", "value_one", "field_two", 123))); + } + } + + public void testParser_withEmptyTaskSettings() throws IOException { + var requestJson = """ + { + "input": "some text input", + "task_settings": {} + } + """; + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = EmbeddingRequest.PARSER.apply(parser, null); + var expectedInputs = List.of( + new InferenceStringGroup(List.of(new InferenceString(DataType.TEXT, DataFormat.TEXT, "some text input"))) + ); + assertThat(request.inputs(), is(expectedInputs)); + assertThat(request.inputType(), is(InputType.UNSPECIFIED)); + assertThat(request.taskSettings(), anEmptyMap()); } } @@ -203,7 +252,11 @@ protected EmbeddingRequest createTestInstance() { } public static EmbeddingRequest createRandom() { - return new EmbeddingRequest(randomEmbeddingContents(), randomFrom(InputType.values())); + return new EmbeddingRequest( + randomEmbeddingContents(), + randomFrom(InputType.values()), + Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8)) + ); } private static List randomEmbeddingContents() { @@ -216,21 +269,27 @@ private static List randomEmbeddingContents() { @Override protected EmbeddingRequest mutateInstance(EmbeddingRequest instance) throws IOException { - if (randomBoolean()) { - var embeddingContents = instance.inputs(); - return new EmbeddingRequest( - randomValueOtherThan(embeddingContents, EmbeddingRequestTests::randomEmbeddingContents), - instance.inputType() + var embeddingContents = instance.inputs(); + var inputType = instance.inputType(); + var taskSettings = instance.taskSettings(); + switch (randomInt(2)) { + case 0 -> embeddingContents = randomValueOtherThan(embeddingContents, EmbeddingRequestTests::randomEmbeddingContents); + case 1 -> inputType = randomValueOtherThan(inputType, () -> randomFrom(InputType.values())); + case 2 -> taskSettings = randomValueOtherThan( + taskSettings, + () -> Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8)) ); - } else { - InputType inputType = instance.inputType(); - return new EmbeddingRequest(instance.inputs(), randomValueOtherThan(inputType, () -> randomFrom(InputType.values()))); } + return new EmbeddingRequest(embeddingContents, inputType, taskSettings); } @Override protected EmbeddingRequest mutateInstanceForVersion(EmbeddingRequest instance, TransportVersion version) { - return instance; + if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + return instance; + } else { + return new EmbeddingRequest(instance.inputs(), instance.inputType(), Map.of()); + } } @Override diff --git a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java index ed1f7db68b497..c873a21d936a1 100644 --- a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java +++ b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java @@ -14,29 +14,63 @@ import org.elasticsearch.inference.InferenceString.DataFormat; import org.elasticsearch.inference.InferenceString.DataType; import org.elasticsearch.test.AbstractBWCSerializationTestCase; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; +import static org.elasticsearch.inference.InferenceStringGroup.containsNonTextEntry; +import static org.elasticsearch.inference.InferenceStringGroup.indexContainingMultipleInferenceStrings; import static org.elasticsearch.inference.InferenceStringGroup.toInferenceStringList; import static org.elasticsearch.inference.InferenceStringGroup.toStringList; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; public class InferenceStringGroupTests extends AbstractBWCSerializationTestCase { public void testStringConstructor() { String stringValue = "a string"; var input = new InferenceStringGroup(stringValue); - assertThat(input.inferenceStrings(), contains(new InferenceString(DataType.TEXT, DataFormat.TEXT, stringValue))); + assertThat(input.inferenceStrings(), contains(new InferenceString(TEXT, DataFormat.TEXT, stringValue))); + assertThat(input.containsNonTextEntry(), is(false)); + assertThat(input.containsMultipleInferenceStrings(), is(false)); } - public void testSingleArgumentConstructor() { + public void testSingleInferenceStringConstructor() { InferenceString inferenceString = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string"); var input = new InferenceStringGroup(inferenceString); assertThat(input.inferenceStrings(), contains(inferenceString)); + assertThat(input.containsNonTextEntry(), is(true)); + assertThat(input.containsMultipleInferenceStrings(), is(false)); + } + + public void testInferenceStringListConstructor() { + InferenceString inferenceString1 = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string"); + InferenceString inferenceString2 = new InferenceString(TEXT, DataFormat.TEXT, "a string"); + var input = new InferenceStringGroup(List.of(inferenceString1, inferenceString2)); + assertThat(input.inferenceStrings(), contains(inferenceString1, inferenceString2)); + assertThat(input.containsNonTextEntry(), is(true)); + assertThat(input.containsMultipleInferenceStrings(), is(true)); + } + + public void testParser_withEmptyContentObject_throws() throws IOException { + var requestJson = """ + { + "content": {} + } + """; + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + // Need to call nextToken() so that the parser is at the correct element + parser.nextToken(); + var exception = expectThrows(XContentParseException.class, () -> InferenceStringGroup.PARSER.apply(parser, null)); + assertThat(exception.getMessage(), containsString("[InferenceStringGroup] failed to parse field [content]")); + } } public void testValue_withMoreThanOneElement_throws() { @@ -77,6 +111,46 @@ public void testToStringList_withMoreThanOneElement_throws() { assertThat(expectedException.getMessage(), is("Multiple-input InferenceStringGroup passed to InferenceStringGroup.toStringList")); } + public void testContainsNonTextEntry_withOnlyTextInputs() { + var inputs = List.of(new InferenceStringGroup("string1"), new InferenceStringGroup("string2")); + assertThat(containsNonTextEntry(inputs), is(false)); + } + + public void testContainsNonTextEntry_withNonTextInput() { + DataType nonTextDataType = randomValueOtherThan(TEXT, () -> randomFrom(DataType.values())); + var inputs = List.of( + new InferenceStringGroup("string1"), + new InferenceStringGroup(new InferenceString(nonTextDataType, "non text")) + ); + assertThat(containsNonTextEntry(inputs), is(true)); + } + + public void testIndexContainingMultipleInferenceStrings_withSingleInferenceString() { + var inputs = getInputsList(); + assertThat(indexContainingMultipleInferenceStrings(inputs), nullValue()); + } + + public void testIndexContainingMultipleInferenceStrings_withMultipleInferenceStrings() { + var inputs = getInputsList(); + + // Add an InferenceStringGroup with multiple InferenceStrings at a random point in the input list + var indexToAdd = randomIntBetween(0, inputs.size() - 1); + var multipleInferenceStrings = new InferenceStringGroup( + List.of(new InferenceString(TEXT, "a_string"), new InferenceString(TEXT, "a_string")) + ); + inputs.add(indexToAdd, multipleInferenceStrings); + assertThat(indexContainingMultipleInferenceStrings(inputs), is(indexToAdd)); + } + + private static ArrayList getInputsList() { + var listSize = randomIntBetween(1, 10); + var inputs = new ArrayList(listSize); + for (int i = 0; i < listSize; ++i) { + inputs.add(new InferenceStringGroup("a_string")); + } + return inputs; + } + @Override protected InferenceStringGroup mutateInstanceForVersion(InferenceStringGroup instance, TransportVersion version) { return instance; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java index 1de31682a3524..553fc6ea99b20 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java @@ -23,7 +23,9 @@ import java.io.IOException; import java.util.List; +import java.util.Map; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; import static org.elasticsearch.xpack.core.inference.action.EmbeddingAction.Request.parseRequest; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -39,7 +41,10 @@ public void testParseRequest() throws IOException { "content": {"type": "image", "format": "base64", "value": "some image input" } } ], - "input_type": "search" + "input_type": "search", + "task_settings": { + "field": "value" + } } """; try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { @@ -53,7 +58,8 @@ public void testParseRequest() throws IOException { taskType, new EmbeddingRequest( List.of(new InferenceStringGroup(new InferenceString(DataType.IMAGE, DataFormat.BASE64, "some image input"))), - InputType.SEARCH + InputType.SEARCH, + Map.of("field", "value") ), context, timeout @@ -73,7 +79,7 @@ public void testValidate_withNullEmbeddingRequestInputs_returnsValidationExcepti var request = new EmbeddingAction.Request( randomAlphanumericOfLength(8), TaskType.EMBEDDING, - new EmbeddingRequest(null, randomFrom(InputType.values())), + new EmbeddingRequest(null, randomFrom(InputType.values()), Map.of()), new InferenceContext(randomAlphaOfLength(10)), TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); @@ -87,7 +93,7 @@ public void testValidate_withEmptyEmbeddingRequestInputs_returnsValidationExcept var request = new EmbeddingAction.Request( randomAlphanumericOfLength(8), TaskType.EMBEDDING, - new EmbeddingRequest(List.of(), randomFrom(InputType.values())), + new EmbeddingRequest(List.of(), randomFrom(InputType.values()), Map.of()), new InferenceContext(randomAlphaOfLength(10)), TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); @@ -112,10 +118,11 @@ public void testValidate_withNonEmbeddingTaskType_returnsValidationException() { } public void testValidate_withMultipleValidationErrors_returnsAll() { + // Create a request with an invalid task type and null inputs var request = new EmbeddingAction.Request( randomAlphanumericOfLength(8), randomValueOtherThanMany(TaskType.EMBEDDING::isAnyOrSame, () -> randomFrom(TaskType.values())), - new EmbeddingRequest(null, randomFrom(InputType.values())), + new EmbeddingRequest(null, randomFrom(InputType.values()), Map.of()), new InferenceContext(randomAlphaOfLength(10)), TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); @@ -128,16 +135,25 @@ public void testValidate_withMultipleValidationErrors_returnsAll() { @Override protected EmbeddingAction.Request mutateInstanceForVersion(EmbeddingAction.Request instance, TransportVersion version) { + // Use empty task settings if node is on a version before Jina AI embedding task support was added, since embedding request task + // settings were added with that change + var embeddingRequest = instance.getEmbeddingRequest(); + if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED) == false) { + embeddingRequest = new EmbeddingRequest(embeddingRequest.inputs(), embeddingRequest.inputType(), Map.of()); + } + + var context = instance.getContext(); if (version.supports(INFERENCE_CONTEXT) == false) { - return new EmbeddingAction.Request( - instance.getInferenceEntityId(), - instance.getTaskType(), - instance.getEmbeddingRequest(), - InferenceContext.EMPTY_INSTANCE, - instance.getTimeout() - ); + context = InferenceContext.EMPTY_INSTANCE; } - return instance; + + return new EmbeddingAction.Request( + instance.getInferenceEntityId(), + instance.getTaskType(), + embeddingRequest, + context, + instance.getTimeout() + ); } @Override @@ -160,7 +176,11 @@ public static EmbeddingAction.Request createRandom() { } private static EmbeddingRequest randomEmbeddingRequest() { - return new EmbeddingRequest(List.of(new InferenceStringGroup(randomAlphanumericOfLength(8))), randomFrom(InputType.values())); + return new EmbeddingRequest( + List.of(new InferenceStringGroup(randomAlphanumericOfLength(8))), + randomFrom(InputType.values()), + Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8)) + ); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java index a7a6f78e26b7f..771512f0d1f13 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java @@ -15,14 +15,13 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.DEFAULT_SETTINGS; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.WORDS_PER_TOKEN; public class ChunkingSettingsBuilderTests extends ESTestCase { - public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1); - public void testNullChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(null); assertEquals(ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS, chunkingSettings); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 69976cb5d6b82..b5c93171dcb29 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -233,7 +233,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { } public void testGetServicesWithEmbeddingTaskType() throws IOException { - assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service").toArray())); + assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service", "jinaai").toArray())); } private List getAllServices() throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index b5f2bc89e0ec8..ea666531bedf2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -106,8 +106,9 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; -import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAITextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings; @@ -860,8 +861,15 @@ private static void addJinaAINamedWriteables(List namedWriteables.add( new NamedWriteableRegistry.Entry( ServiceSettings.class, - JinaAIEmbeddingsServiceSettings.NAME, - JinaAIEmbeddingsServiceSettings::new + JinaAITextEmbeddingServiceSettings.NAME, + JinaAITextEmbeddingServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + JinaAIEmbeddingServiceSettings.NAME, + JinaAIEmbeddingServiceSettings::new ) ); namedWriteables.add( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 06c6c6e3aeeac..1e3d178da9c2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -43,6 +43,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceString; import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -79,6 +80,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static java.util.Collections.singletonList; import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; @@ -396,8 +398,14 @@ private void executeChunkedInferenceAsync( } } + // This assumes that all inference requests are text only, with no images final List inputs = requests.stream() - .map(r -> new ChunkInferenceInput(new InferenceStringGroup(r.input), r.chunkingSettings)) + .map( + r -> new ChunkInferenceInput( + new InferenceStringGroup(singletonList(new InferenceString(InferenceString.DataType.TEXT, r.input))), + r.chunkingSettings + ) + ) .collect(Collectors.toList()); ActionListener> completionListener = ActionListener.wrap(results -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 47b394eed3c39..60a4ee2cf48a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -21,6 +21,7 @@ import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; @@ -41,6 +42,7 @@ import java.util.Objects; import java.util.Set; +import static org.elasticsearch.inference.InferenceStringGroup.indexContainingMultipleInferenceStrings; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedEmbeddingOperation; public abstract class SenderService implements InferenceService { @@ -138,9 +140,38 @@ public void unifiedCompletionInfer( @Override public void embeddingInfer(Model model, EmbeddingRequest request, TimeValue timeout, ActionListener listener) { - SubscribableListener.newForked(this::init) - .andThen((embeddingInferListener) -> doEmbeddingInfer(model, request, timeout, embeddingInferListener)) - .addListener(listener); + SubscribableListener.newForked(this::init).andThen((embeddingInferListener) -> { + if (supportsMultipleItemsPerContent()) { + doEmbeddingInfer(model, request, timeout, embeddingInferListener); + } else { + var index = indexContainingMultipleInferenceStrings(request.inputs()); + if (index == null) { + doEmbeddingInfer(model, request, timeout, embeddingInferListener); + } else { + listener.onFailure( + new ElasticsearchStatusException( + Strings.format( + "Field [%1$s] must contain a single item for [%2$s] service. " + + "[%1$s] object with multiple items found at $.%3$s.%1$s[%4$d]", + InferenceStringGroup.CONTENT_FIELD, + name(), + EmbeddingRequest.INPUT_FIELD, + index + ), + RestStatus.BAD_REQUEST + ) + ); + } + } + }).addListener(listener); + } + + /** + * Override as necessary for services which support generating a single embedding vector from multiple inputs + * @return true if the service supports sending embedding requests where multiple inputs are used to generate a single embedding vector + */ + protected boolean supportsMultipleItemsPerContent() { + return false; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java index 6fed5218a93ad..bd87a04afab78 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java @@ -29,6 +29,11 @@ public final class ServiceFields { * The field used by services other than elasticsearch to determine the embedding type */ public static final String EMBEDDING_TYPE = "embedding_type"; + /** + * The name of the field used to specify whether the model supports multimodal inputs for the + * {@link org.elasticsearch.inference.TaskType#EMBEDDING} task type. Defaults to true. + */ + public static final String MULTIMODAL_MODEL = "multimodal_model"; private ServiceFields() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index ab72aab578365..21f72872627c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -22,7 +21,6 @@ import org.elasticsearch.xpack.inference.services.jinaai.request.JinaAIEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.jinaai.response.JinaAIEmbeddingsResponseEntity; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -52,11 +50,11 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); - List docsInput = input.getTextInputs(); - InputType inputType = input.getInputType(); + var embeddingsInput = inferenceInputs.castTo(EmbeddingsInput.class); + var inferenceStringGroups = embeddingsInput.getInputs(); + var inputType = embeddingsInput.getInputType(); - JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); + var request = new JinaAIEmbeddingsRequest(inferenceStringGroups, inputType, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 88372c2c70244..a3f3734d22576 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.jinaai; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -30,8 +32,10 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -43,7 +47,6 @@ import org.elasticsearch.xpack.inference.services.jinaai.action.JinaAIActionCreator; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -54,6 +57,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.inference.InferenceStringGroup.containsNonTextEntry; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; @@ -66,6 +70,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceFields.EMBEDDING_MAX_BATCH_SIZE; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.updateEmbeddingDetails; public class JinaAIService extends SenderService implements RerankingInferenceService { @@ -74,7 +79,7 @@ public class JinaAIService extends SenderService implements RerankingInferenceSe public static final String NAME = "jinaai"; private static final String SERVICE_NAME = "Jina AI"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.EMBEDDING); public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -114,7 +119,7 @@ public void parseRequestConfig( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap( removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) ); @@ -168,13 +173,14 @@ private static JinaAIModel createModel( ConfigurationParseContext context ) { return switch (taskType) { - case TEXT_EMBEDDING -> new JinaAIEmbeddingsModel( + case TEXT_EMBEDDING, EMBEDDING -> new JinaAIEmbeddingsModel( inferenceEntityId, serviceSettings, taskSettings, chunkingSettings, secretSettings, - context + context, + taskType ); case RERANK -> new JinaAIRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context); @@ -193,7 +199,7 @@ public JinaAIModel parsePersistedConfigWithSecrets( Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -213,7 +219,7 @@ public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -285,7 +291,8 @@ protected void doChunkedInfer( boolean batchChunksAcrossInputs = true; if (jinaaiModel.getTaskSettings() instanceof JinaAIEmbeddingsTaskSettings jinaAIEmbeddingsTaskSettings) { batchChunksAcrossInputs = jinaAIEmbeddingsTaskSettings.getLateChunking() == null - || jinaAIEmbeddingsTaskSettings.getLateChunking() == false; + || jinaAIEmbeddingsTaskSettings.getLateChunking() == false + || inputs.stream().anyMatch(c -> c.input().containsNonTextEntry()); } List batchedRequests = new EmbeddingRequestChunker<>( @@ -301,25 +308,40 @@ protected void doChunkedInfer( } } + @Override + protected void doEmbeddingInfer( + Model model, + EmbeddingRequest request, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof JinaAIEmbeddingsModel jinaAIModel) { + if (model.getServiceSettings().isMultimodal() == false && containsNonTextEntry(request.inputs())) { + listener.onFailure(new ElasticsearchStatusException("Non-text input provided for text-only model", RestStatus.BAD_REQUEST)); + } else { + var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); + + ExecutableAction action = jinaAIModel.accept(actionCreator, request.taskSettings()); + action.execute(new EmbeddingsInput(request::inputs, request.inputType()), timeout, listener); + } + } else { + listener.onFailure(createInvalidModelException(model)); + } + + } + @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof JinaAIEmbeddingsModel embeddingsModel) { var serviceSettings = embeddingsModel.getServiceSettings(); var similarityFromModel = serviceSettings.similarity(); var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel; - var maxInputTokens = serviceSettings.maxInputTokens(); - - var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings( - serviceSettings.getCommonSettings().modelId(), - serviceSettings.getCommonSettings().rateLimitSettings() - ), - similarityToUse, - embeddingSize, - maxInputTokens, - serviceSettings.getEmbeddingType(), - serviceSettings.dimensionsSetByUser() - ); + + var updatedServiceSettings = updateEmbeddingDetails(serviceSettings, embeddingSize, similarityToUse); + + if (updatedServiceSettings.equals(serviceSettings)) { + return model; + } return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); } else { @@ -379,7 +401,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( DIMENSIONS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription( "The number of dimensions the resulting embeddings should have. For more information refer to " + "https://api.jina.ai/docs#tag/embeddings/operation/create_embedding_v1_embeddings_post." ) @@ -393,7 +415,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( EMBEDDING_TYPE, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription( Strings.format( "The type of embedding to return. One of %s. bit and binary are equivalent and are encoded as " + "bytes with signed int8 precision.", @@ -411,7 +433,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( SIMILARITY, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription( Strings.format( "The similarity measure. One of %s. For float embeddings, the default similarity " + "is dot_product. For bit and binary embeddings, the default similarity is l2_norm.", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java similarity index 65% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java index 134e1a0029d80..833c58b3118e4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettings.java @@ -26,9 +26,12 @@ import java.util.EnumSet; import java.util.Map; import java.util.Objects; +import java.util.function.BiFunction; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; +import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; @@ -36,21 +39,44 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; -public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { - public static final String NAME = "jinaai_embeddings_service_settings"; +public abstract class BaseJinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { - public static JinaAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName("jina_ai_embedding_type_support_added"); + + static final TransportVersion JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED = TransportVersion.fromName( + "jina_ai_embedding_dimensions_support_added" + ); + + @FunctionalInterface + public interface ConstructorInvoker { + T construct( + JinaAIServiceSettings commonSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + boolean dimensionsSetByUser, + boolean multimodalModel + ); + } + + static T fromMap( + Map map, + ConfigurationParseContext context, + BiFunction, ValidationException, Boolean> handleMultimodalModelField, + ConstructorInvoker constructor + ) { ValidationException validationException = new ValidationException(); var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class, validationException); JinaAIEmbeddingType embeddingType = parseEmbeddingType(map, validationException); Boolean dimensionsSetByUser; if (context == ConfigurationParseContext.PERSISTENT) { - dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); + dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class, validationException); if (dimensionsSetByUser == null) { dimensionsSetByUser = Boolean.FALSE; } @@ -58,17 +84,20 @@ public static JinaAIEmbeddingsServiceSettings fromMap(Map map, C dimensionsSetByUser = dimensions != null; } + boolean multimodalModel = handleMultimodalModelField.apply(map, validationException); + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new JinaAIEmbeddingsServiceSettings( + return constructor.construct( commonServiceSettings, similarity, dimensions, maxInputTokens, embeddingType, - dimensionsSetByUser + dimensionsSetByUser, + multimodalModel ); } @@ -76,7 +105,7 @@ static JinaAIEmbeddingType parseEmbeddingType(Map map, Validatio return Objects.requireNonNullElse( extractOptionalEnum( map, - ServiceFields.EMBEDDING_TYPE, + EMBEDDING_TYPE, ModelConfigurations.SERVICE_SETTINGS, JinaAIEmbeddingType::fromString, EnumSet.allOf(JinaAIEmbeddingType.class), @@ -86,28 +115,33 @@ static JinaAIEmbeddingType parseEmbeddingType(Map map, Validatio ); } - private static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_type_support_added" - ); - - static final TransportVersion JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_dimensions_support_added" - ); + public static BaseJinaAIEmbeddingsServiceSettings updateEmbeddingDetails( + BaseJinaAIEmbeddingsServiceSettings existingSettings, + Integer embeddingSize, + SimilarityMeasure similarityToUse + ) { + if (embeddingSize.equals(existingSettings.dimensions()) && similarityToUse.equals(existingSettings.similarity())) { + return existingSettings; + } + return existingSettings.update(similarityToUse, embeddingSize); + } private final JinaAIServiceSettings commonSettings; private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; private final JinaAIEmbeddingType embeddingType; - private final Boolean dimensionsSetByUser; + private final boolean dimensionsSetByUser; + private final boolean multimodalModel; - public JinaAIEmbeddingsServiceSettings( + public BaseJinaAIEmbeddingsServiceSettings( JinaAIServiceSettings commonSettings, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable JinaAIEmbeddingType embeddingType, - boolean dimensionsSetByUser + boolean dimensionsSetByUser, + boolean multimodalModel ) { this.commonSettings = commonSettings; this.similarity = similarity; @@ -115,25 +149,43 @@ public JinaAIEmbeddingsServiceSettings( this.maxInputTokens = maxInputTokens; this.embeddingType = embeddingType != null ? embeddingType : JinaAIEmbeddingType.FLOAT; this.dimensionsSetByUser = dimensionsSetByUser; + this.multimodalModel = multimodalModel; } - public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { + public BaseJinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { this.commonSettings = new JinaAIServiceSettings(in); this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); - - this.embeddingType = (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) - ? Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT) - : JinaAIEmbeddingType.FLOAT; + if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) { + this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT); + } else { + this.embeddingType = JinaAIEmbeddingType.FLOAT; + } if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) { this.dimensionsSetByUser = in.readBoolean(); } else { this.dimensionsSetByUser = false; } + + if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + this.multimodalModel = in.readBoolean(); + } else { + this.multimodalModel = false; + } } + /** + * Returns a new {@link BaseJinaAIEmbeddingsServiceSettings} with updated similarity and dimensions but all other fields unchanged + * @param similarity the new similarity + * @param dimensions the new dimensions + * @return a new {@link BaseJinaAIEmbeddingsServiceSettings} + */ + public abstract BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions); + + protected abstract void optionallyWriteMultimodalField(XContentBuilder builder) throws IOException; + public JinaAIServiceSettings getCommonSettings() { return commonSettings; } @@ -172,8 +224,8 @@ public DenseVectorFieldMapper.ElementType elementType() { } @Override - public String getWriteableName() { - return NAME; + public boolean isMultimodal() { + return multimodalModel; } @Override @@ -200,10 +252,13 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + if (similarity != null) { builder.field(SIMILARITY, similarity); } + optionallyWriteMultimodalField(builder); + return builder; } @@ -218,7 +273,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); - if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) { out.writeOptionalEnum(JinaAIEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion())); } @@ -226,23 +280,48 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) { out.writeBoolean(dimensionsSetByUser); } + + if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) { + out.writeOptionalBoolean(multimodalModel); + } } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - JinaAIEmbeddingsServiceSettings that = (JinaAIEmbeddingsServiceSettings) o; + BaseJinaAIEmbeddingsServiceSettings that = (BaseJinaAIEmbeddingsServiceSettings) o; return Objects.equals(commonSettings, that.commonSettings) && Objects.equals(similarity, that.similarity) && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) && Objects.equals(embeddingType, that.embeddingType) - && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser); + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(multimodalModel, that.multimodalModel); } @Override public int hashCode() { - return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser); + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser, multimodalModel); + } + + @Override + public String toString() { + return "BaseJinaAIEmbeddingsServiceSettings{" + + "commonSettings=" + + commonSettings + + ", similarity=" + + similarity + + ", dimensions=" + + dimensions + + ", maxInputTokens=" + + maxInputTokens + + ", embeddingType=" + + embeddingType + + ", dimensionsSetByUser=" + + dimensionsSetByUser + + ", multimodalModel=" + + multimodalModel + + '}'; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java new file mode 100644 index 0000000000000..43be6b00fd591 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettings.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class JinaAIEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings { + public static final String NAME = "jinaai_multimodal_embedding_service_settings"; + public static final boolean DEFAULT_MULTIMODAL_MODEL = true; + + public static JinaAIEmbeddingServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return BaseJinaAIEmbeddingsServiceSettings.fromMap( + map, + context, + (m, v) -> Objects.requireNonNullElse(removeAsType(m, MULTIMODAL_MODEL, Boolean.class, v), DEFAULT_MULTIMODAL_MODEL), + JinaAIEmbeddingServiceSettings::new + ); + } + + public JinaAIEmbeddingServiceSettings( + JinaAIServiceSettings commonSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + boolean dimensionsSetByUser, + boolean multimodalModel + ) { + super(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser, multimodalModel); + } + + public JinaAIEmbeddingServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions) { + return new JinaAIEmbeddingServiceSettings( + getCommonSettings(), + similarity, + dimensions, + maxInputTokens(), + getEmbeddingType(), + dimensionsSetByUser(), + isMultimodal() + ); + } + + @Override + protected void optionallyWriteMultimodalField(XContentBuilder builder) throws IOException { + builder.field(MULTIMODAL_MODEL, isMultimodal()); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java index 3bccc940b7fe3..5d22d93e77594 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java @@ -16,6 +16,8 @@ import java.util.Locale; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; + /** * Defines the type of embedding that the Jina AI API should return for a request. * @@ -49,10 +51,6 @@ private static final class RequestConstants { ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.keySet() ); - private static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_type_support_added" - ); - private final DenseVectorFieldMapper.ElementType elementType; private final String requestString; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java index 95f7b570f673a..fb99e5077050c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java @@ -47,29 +47,32 @@ public JinaAIEmbeddingsModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secrets, - ConfigurationParseContext context + ConfigurationParseContext context, + TaskType taskType ) { this( inferenceId, - JinaAIEmbeddingsServiceSettings.fromMap(serviceSettings, context), + createServiceSettings(serviceSettings, taskType, context), JinaAIEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, DefaultSecretSettings.fromMap(secrets), - null + null, + taskType ); } // should only be used for testing JinaAIEmbeddingsModel( String modelId, - JinaAIEmbeddingsServiceSettings serviceSettings, + BaseJinaAIEmbeddingsServiceSettings serviceSettings, JinaAIEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings, - @Nullable String uri + @Nullable String uri, + TaskType taskType ) { super( - new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, JinaAIService.NAME, serviceSettings, taskSettings, chunkingSettings), + new ModelConfigurations(modelId, taskType, JinaAIService.NAME, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings.getCommonSettings(), @@ -81,13 +84,13 @@ private JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, JinaAIEmbeddingsTaskS super(model, taskSettings); } - public JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, JinaAIEmbeddingsServiceSettings serviceSettings) { + public JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, BaseJinaAIEmbeddingsServiceSettings serviceSettings) { super(model, serviceSettings); } @Override - public JinaAIEmbeddingsServiceSettings getServiceSettings() { - return (JinaAIEmbeddingsServiceSettings) super.getServiceSettings(); + public BaseJinaAIEmbeddingsServiceSettings getServiceSettings() { + return (BaseJinaAIEmbeddingsServiceSettings) super.getServiceSettings(); } @Override @@ -104,4 +107,17 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(JinaAIActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } + + private static BaseJinaAIEmbeddingsServiceSettings createServiceSettings( + Map serviceSettings, + TaskType taskType, + ConfigurationParseContext context + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> JinaAITextEmbeddingServiceSettings.fromMap(serviceSettings, context); + case EMBEDDING -> JinaAIEmbeddingServiceSettings.fromMap(serviceSettings, context); + // Should not be possible + default -> throw new IllegalArgumentException(); + }; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java new file mode 100644 index 0000000000000..88d3e2e350d91 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettings.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; + +import java.io.IOException; +import java.util.Map; + +public class JinaAITextEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings { + /** + * This name is a holdover from before the introduction of {@link JinaAIEmbeddingServiceSettings} to support multimodal embeddings + * This name cannot be changed due to backwards compatibility, but it should be 'jinaai_text_embedding_service_settings' + */ + public static final String NAME = "jinaai_embeddings_service_settings"; + public static final boolean DEFAULT_MULTIMODAL_MODEL = false; + + public static JinaAITextEmbeddingServiceSettings fromMap(Map map, ConfigurationParseContext context) { + return BaseJinaAIEmbeddingsServiceSettings.fromMap( + map, + context, + (m, v) -> DEFAULT_MULTIMODAL_MODEL, + JinaAITextEmbeddingServiceSettings::new + ); + } + + private JinaAITextEmbeddingServiceSettings( + JinaAIServiceSettings commonServiceSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dims, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingTypes, + boolean dimensionsSetByUser, + boolean multimodalModel + ) { + super(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes, dimensionsSetByUser, DEFAULT_MULTIMODAL_MODEL); + } + + public JinaAITextEmbeddingServiceSettings( + JinaAIServiceSettings commonServiceSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dims, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingTypes, + boolean dimensionsSetByUser + ) { + this(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes, dimensionsSetByUser, DEFAULT_MULTIMODAL_MODEL); + } + + public JinaAITextEmbeddingServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public BaseJinaAIEmbeddingsServiceSettings update(SimilarityMeasure similarity, Integer dimensions) { + return new JinaAITextEmbeddingServiceSettings( + getCommonSettings(), + similarity, + dimensions, + maxInputTokens(), + getEmbeddingType(), + dimensionsSetByUser() + ); + } + + @Override + protected void optionallyWriteMultimodalField(XContentBuilder builder) { + // Do not include the multimodal_model field for text_embedding, because it is always false + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java index 7027533746884..e19d06e8fe08c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java @@ -10,7 +10,9 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; @@ -23,11 +25,11 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest { - private final List input; + private final List input; private final InputType inputType; private final JinaAIEmbeddingsModel model; - public JinaAIEmbeddingsRequest(List input, InputType inputType, JinaAIEmbeddingsModel embeddingsModel) { + public JinaAIEmbeddingsRequest(List input, InputType inputType, JinaAIEmbeddingsModel embeddingsModel) { this.input = Objects.requireNonNull(input); this.inputType = inputType; this.model = Objects.requireNonNull(embeddingsModel); @@ -70,4 +72,8 @@ public boolean[] getTruncationInfo() { public JinaAIEmbeddingType getEmbeddingType() { return model.getServiceSettings().getEmbeddingType(); } + + public TaskType getTaskType() { + return model.getTaskType(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java index c68731882ffc6..a5390d58e8b4b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.jinaai.request; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -18,9 +19,10 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.inference.InferenceStringGroup.toStringList; import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; -public record JinaAIEmbeddingsRequestEntity(List input, InputType inputType, JinaAIEmbeddingsModel model) +public record JinaAIEmbeddingsRequestEntity(List input, InputType inputType, JinaAIEmbeddingsModel model) implements ToXContentObject { @@ -29,6 +31,8 @@ public record JinaAIEmbeddingsRequestEntity(List input, InputType inputT private static final String CLUSTERING = "separation"; private static final String CLASSIFICATION = "classification"; private static final String INPUT_FIELD = "input"; + private static final String INPUT_TEXT_FIELD = "text"; + private static final String INPUT_IMAGE_FIELD = "image"; private static final String MODEL_FIELD = "model"; public static final String TASK_TYPE_FIELD = "task"; public static final String LATE_CHUNKING = "late_chunking"; @@ -46,7 +50,7 @@ public record JinaAIEmbeddingsRequestEntity(List input, InputType inputT @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(INPUT_FIELD, input); + writeInputs(builder); builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); builder.field(EMBEDDING_TYPE_FIELD, model.getServiceSettings().getEmbeddingType().toRequestString()); @@ -60,7 +64,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } if (taskSettings.getLateChunking() != null) { - builder.field(LATE_CHUNKING, taskSettings.getLateChunking() && getInputWordCount() <= MAX_WORD_COUNT_FOR_LATE_CHUNKING); + builder.field( + LATE_CHUNKING, + // Late chunking is not supported for image inputs + taskSettings.getLateChunking() + && InferenceStringGroup.containsNonTextEntry(input) == false + && getInputWordCount() <= MAX_WORD_COUNT_FOR_LATE_CHUNKING + ); } if (model.getServiceSettings().dimensionsSetByUser() && model.getServiceSettings().dimensions() != null) { @@ -71,6 +81,25 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + private void writeInputs(XContentBuilder builder) throws IOException { + if (model.getServiceSettings().isMultimodal()) { + builder.startArray(INPUT_FIELD); + for (var inferenceStringGroup : input) { + var inferenceString = inferenceStringGroup.value(); + builder.startObject(); + if (inferenceString.isText()) { + builder.field(INPUT_TEXT_FIELD, inferenceString.value()); + } else if (inferenceString.isImage()) { + builder.field(INPUT_IMAGE_FIELD, inferenceString.value()); + } + builder.endObject(); + } + builder.endArray(); + } else { + builder.field(INPUT_FIELD, toStringList(input)); + } + } + // default for testing static String convertInputType(InputType inputType) { return switch (inputType) { @@ -87,8 +116,8 @@ static String convertInputType(InputType inputType) { private int getInputWordCount() { int wordCount = 0; - for (var text : input) { - wordCount += ChunkerUtils.countWords(text); + for (var inferenceStringGroup : input) { + wordCount += ChunkerUtils.countWords(inferenceStringGroup.textValue()); } return wordCount; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java index 4471012b6e57c..ed35f50ae820b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java @@ -7,10 +7,11 @@ package org.elasticsearch.xpack.inference.services.jinaai.response; +import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -18,6 +19,10 @@ import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -39,14 +44,15 @@ public class JinaAIEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response"; - private static final Map> EMBEDDING_PARSERS = Map.of( - toLowerCase(JinaAIEmbeddingType.FLOAT), - JinaAIEmbeddingsResponseEntity::parseFloatDataObject, - toLowerCase(JinaAIEmbeddingType.BIT), - JinaAIEmbeddingsResponseEntity::parseBitDataObject, - toLowerCase(JinaAIEmbeddingType.BINARY), - JinaAIEmbeddingsResponseEntity::parseBitDataObject - ); + private static final Map> EMBEDDING_PARSERS = + Map.of( + toLowerCase(JinaAIEmbeddingType.FLOAT), + JinaAIEmbeddingsResponseEntity::parseFloatDataObject, + toLowerCase(JinaAIEmbeddingType.BIT), + JinaAIEmbeddingsResponseEntity::parseBitDataObject, + toLowerCase(JinaAIEmbeddingType.BINARY), + JinaAIEmbeddingsResponseEntity::parseBitDataObject + ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); private static String supportedEmbeddingTypes() { @@ -112,6 +118,7 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r ); } + var taskType = embeddingsRequest.getTaskType(); var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { moveToFirstToken(jsonParser); @@ -121,20 +128,26 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); - return embeddingValueParser.apply(jsonParser); + return embeddingValueParser.apply(jsonParser, taskType); } } - private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException { - List embeddingList = parseList( + private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser, TaskType taskType) throws IOException { + List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject ); - return new DenseEmbeddingFloatResults(embeddingList); + if (taskType == TaskType.TEXT_EMBEDDING) { + return new DenseEmbeddingFloatResults(embeddingList); + } else if (taskType == TaskType.EMBEDDING) { + return new GenericDenseEmbeddingFloatResults(embeddingList); + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } } - private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { + private static EmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -143,19 +156,25 @@ private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XC // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return EmbeddingFloatResults.Embedding.of(embeddingValuesList); } - private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException { + private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser, TaskType taskType) throws IOException { List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject ); - return new DenseEmbeddingBitResults(embeddingList); + if (taskType == TaskType.TEXT_EMBEDDING) { + return new DenseEmbeddingBitResults(embeddingList); + } else if (taskType == TaskType.EMBEDDING) { + return new GenericDenseEmbeddingBitResults(embeddingList); + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } } - private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { + private static EmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -164,7 +183,7 @@ private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XCont // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return DenseEmbeddingByteResults.Embedding.of(embeddingList); + return EmbeddingByteResults.Embedding.of(embeddingList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java index 9140795d3d796..f4d2610688922 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidator.java @@ -23,6 +23,7 @@ import org.elasticsearch.rest.RestStatus; import java.util.List; +import java.util.Map; public class SimpleEmbeddingServiceIntegrationValidator implements ServiceIntegrationValidator { // The below data URI represents the base64 encoding of 28x28 pixel black square .jpg image @@ -44,7 +45,13 @@ public class SimpleEmbeddingServiceIntegrationValidator implements ServiceIntegr @Override public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { - EmbeddingRequest request = new EmbeddingRequest(List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT), InputType.INTERNAL_INGEST); + List inputList; + if (model.getServiceSettings().isMultimodal()) { + inputList = List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT); + } else { + inputList = List.of(TEST_TEXT_INPUT); + } + EmbeddingRequest request = new EmbeddingRequest(inputList, InputType.INTERNAL_INGEST, Map.of()); service.embeddingInfer(model, request, timeout, ActionListener.wrap(r -> { if (r != null) { listener.onResponse(r); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java index 9f2055d825fd7..4da01e0aca8af 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -44,6 +44,14 @@ public static InputType randomWithNull() { ); } + public static InputType randomRequestType() { + return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, InputType.CLASSIFICATION); + } + + public static InputType randomRequestTypeWithNull() { + return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, InputType.CLASSIFICATION, null); + } + public static InputType randomSearchAndIngestWithNull() { return randomBoolean() ? null diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java index f6c058bdbb79f..4d8244c22c83c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java @@ -24,4 +24,7 @@ public void testFromStringOrStatusException() { assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY)); } + public static TaskType randomEmbeddingTaskType() { + return randomFrom(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index cfd9216e17a07..c232f92bc50f8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.Model; @@ -155,12 +156,19 @@ public record PersistedConfig(Map config, Map se public static PersistedConfig getPersistedConfigMap( Map serviceSettings, Map taskSettings, - Map chunkingSettings, - Map secretSettings + @Nullable Map chunkingSettings, + @Nullable Map secretSettings ) { + var secrets = secretSettings == null ? null : new HashMap(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)); + + var persistedConfigMap = new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + secrets + ); - var persistedConfigMap = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - persistedConfigMap.config.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + if (chunkingSettings != null) { + persistedConfigMap.config.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + } return persistedConfigMap; } @@ -168,27 +176,19 @@ public static PersistedConfig getPersistedConfigMap( public static PersistedConfig getPersistedConfigMap( Map serviceSettings, Map taskSettings, - Map secretSettings + @Nullable Map secretSettings ) { - var secrets = secretSettings == null ? null : new HashMap(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)); + return getPersistedConfigMap(serviceSettings, taskSettings, null, secretSettings); + } - return new PersistedConfig( - new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), - secrets - ); + public static PersistedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + return Utils.getPersistedConfigMap(serviceSettings, taskSettings, null); } public static PersistedConfig getPersistedConfigMap(Map serviceSettings) { return Utils.getPersistedConfigMap(serviceSettings, new HashMap<>(), null); } - public static PersistedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { - return new PersistedConfig( - new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), - null - ); - } - public static Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java index 3e89c0c0425d9..fdec51a1a57d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; @@ -23,6 +24,7 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; import static org.hamcrest.Matchers.is; public class JinaAIServiceSettingsTests extends AbstractBWCWireSerializationTestCase { @@ -104,11 +106,15 @@ protected JinaAIServiceSettings mutateInstance(JinaAIServiceSettings instance) t return new JinaAIServiceSettings(modelId, rateLimitSettings); } - public static Map getServiceSettingsMap(String model) { + public static Map getServiceSettingsMap(String model, @Nullable Integer requestsPerMinute) { var map = new HashMap(); map.put(ServiceFields.MODEL_ID, model); + if (requestsPerMinute != null) { + map.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))); + } + return map; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 4670d32b8db61..8095d4ccb0f48 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -18,27 +18,37 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceString; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsOptions; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -47,10 +57,14 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; -import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModelTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -58,26 +72,41 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.InferenceString.DataFormat.BASE64; +import static org.elasticsearch.inference.InferenceString.DataType.IMAGE; +import static org.elasticsearch.inference.InferenceString.DataType.TEXT; +import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS; +import static org.elasticsearch.inference.TaskType.EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.DEFAULT_SETTINGS; +import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS; import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; -import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS; +import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfMinimalEmbeddingSettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -92,8 +121,8 @@ @SuppressWarnings("resource") public class JinaAIServiceTests extends InferenceServiceTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); public static final String DEFAULT_EMBEDDING_URL = "https://api.jina.ai/v1/embeddings"; + public static final String DEFAULT_RERANK_URL = "https://api.jina.ai/v1/rerank"; private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; private HttpClientManager clientManager; @@ -112,188 +141,273 @@ public void shutdown() throws IOException { webServer.close(); } - public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { + public void testParseRequestConfig_createsEmbeddingsModel_textEmbeddingTask() throws IOException { + testParseRequestConfig_createsEmbeddingModel(TEXT_EMBEDDING); + } + + public void testParseRequestConfig_createsEmbeddingsModel_embeddingTask() throws IOException { + testParseRequestConfig_createsEmbeddingModel(EMBEDDING); + } + + private void testParseRequestConfig_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var maxInputTokens = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var multimodalModel = taskType == EMBEDDING && randomBoolean(); + var inputType = InputTypeTests.randomRequestType(); + var lateChunking = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + var chunkingSettings = createRandomChunkingSettings(); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT)); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var serviceSettingsMap = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute + ); + + if (taskType == EMBEDDING) { + serviceSettingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } + + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, + taskType, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - getSecretSettingsMap("secret") + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(inputType, lateChunking), + chunkingSettings.asMap(), + getSecretSettingsMap(apiKey) ), modelListener ); + assertEmbeddingModelSettings( + modelListener.actionGet(), + modelName, + new RateLimitSettings(requestsPerMinute), + similarity, + dimensions, + true, + maxInputTokens, + embeddingType, + multimodalModel, + new JinaAIEmbeddingsTaskSettings(inputType, lateChunking), + chunkingSettings, + apiKey + ); } } - public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParseRequestConfig_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT)); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, + TaskType.RERANK, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") + JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute), + JinaAIRerankTaskSettingsTests.getTaskSettingsMap(topN, returnDocuments), + getSecretSettingsMap(apiKey) ), modelListener ); + assertRerankModelSettings( + modelListener.actionGet(), + modelName, + new RateLimitSettings(requestsPerMinute), + apiKey, + new JinaAIRerankTaskSettings(topN, returnDocuments) + ); } } - public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel_textEmbedding() throws IOException { + testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel(TEXT_EMBEDDING); + } + + public void testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel_embedding() throws IOException { + testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel(EMBEDDING); + } + + private void testParseRequestConfig_onlyRequiredSettings_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.BIT)); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, + taskType, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.BIT), - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), - getSecretSettingsMap("secret") + getMapOfCommonEmbeddingSettings(modelName, null, null, null, null, null, null), + getSecretSettingsMap(apiKey) ), modelListener ); + assertEmbeddingModelSettings( + modelListener.actionGet(), + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + null, + null, + false, + null, + JinaAIEmbeddingType.FLOAT, + taskType == EMBEDDING, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + DEFAULT_SETTINGS, + apiKey + ); } } - public void testParseRequestConfig_OptionalTaskSettings() throws IOException { + public void testParseRequestConfig_onlyRequiredSettings_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); - ActionListener modelListener = ActionListener.wrap(model -> { - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), equalTo(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + var modelListener = new PlainActionFuture(); service.parseRequestConfig( "id", - TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - getSecretSettingsMap("secret") - ), + TaskType.RERANK, + getRequestConfigMap(JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName), getSecretSettingsMap(apiKey)), modelListener ); + assertRerankModelSettings( + modelListener.actionGet(), + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + apiKey, + JinaAIRerankTaskSettings.EMPTY_SETTINGS + ); + } } - public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOException { + public void testParseRequestConfig_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { - var failureListener = getModelListenerForStatusException("The [jinaai] service does not support task type [sparse_embedding]"); + var unsupportedTaskType = randomValueOtherThanMany( + t -> service.supportedTaskTypes().contains(t), + () -> randomFrom(TaskType.values()) + ); + var failureListener = getModelListenerForStatusException( + Strings.format("The [jinaai] service does not support task type [%s]", unsupportedTaskType) + ); service.parseRequestConfig( "id", - TaskType.SPARSE_EMBEDDING, - getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") - ), + unsupportedTaskType, + getRequestConfigMap(getMapOfMinimalEmbeddingSettings("model"), getSecretSettingsMap("secret")), failureListener ); } } - private static ActionListener getModelListenerForStatusException(String expectedMessage) { - return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is(expectedMessage)); - }); - } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { - var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") - ); + var config = getRequestConfigMap(getMapOfMinimalEmbeddingSettings("model"), getSecretSettingsMap("secret")); config.put("extra_key", "value"); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { try (var service = createJinaAIService()) { - var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT); + var serviceSettings = getMapOfMinimalEmbeddingSettings("model"); serviceSettings.put("extra_key", "value"); - var config = getRequestConfigMap(serviceSettings, Map.of(), getSecretSettingsMap("secret")); + var config = getRequestConfigMap(serviceSettings, getSecretSettingsMap("secret")); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); + } + } + + public void testParseRequestConfig_textEmbedding_throwsWhenMultimodalModelKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createJinaAIService()) { + var serviceSettings = getMapOfMinimalEmbeddingSettings("model"); + serviceSettings.put(MULTIMODAL_MODEL, true); + + var config = getRequestConfigMap(serviceSettings, getSecretSettingsMap("secret")); + + var failureListener = getModelListenerForStatusException( + "Configuration contains settings [{multimodal_model=true}] unknown to the [jinaai] service" + ); + service.parseRequestConfig("id", TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_embedding_doesNotThrowWhenMultimodalModelKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createJinaAIService()) { + String modelName = "model"; + var serviceSettings = getMapOfMinimalEmbeddingSettings(modelName); + var multimodalModel = randomBoolean(); + serviceSettings.put(MULTIMODAL_MODEL, multimodalModel); + + String apiKey = "secret"; + var config = getRequestConfigMap(serviceSettings, getSecretSettingsMap(apiKey)); + + var modelListener = new PlainActionFuture(); + service.parseRequestConfig("id", EMBEDDING, config, modelListener); + + assertEmbeddingModelSettings( + modelListener.actionGet(), + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + null, + null, + false, + null, + JinaAIEmbeddingType.FLOAT, + multimodalModel, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + DEFAULT_SETTINGS, + apiKey + ); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { try (var service = createJinaAIService()) { - var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); - taskSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - taskSettingsMap, + getMapOfMinimalEmbeddingSettings("model"), + new HashMap<>(Map.of("extra_key", "value")), getSecretSettingsMap("secret") ); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); - + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); } } @@ -302,439 +416,440 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); - var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - secretSettingsMap - ); + var config = getRequestConfigMap(getMapOfMinimalEmbeddingSettings("model"), secretSettingsMap); var failureListener = getModelListenerForStatusException( "Configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + service.parseRequestConfig("id", randomFrom(service.supportedTaskTypes()), config, failureListener); } } - public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModel() throws IOException { + public void testParsePersistedConfigWithSecrets_createsEmbeddingsModel_textEmbedding() throws IOException { + testParsePersistedConfigWithSecrets_createsEmbeddingModel(TEXT_EMBEDDING); + } + + public void testParsePersistedConfigWithSecrets_createsEmbeddingsModel_embedding() throws IOException { + testParsePersistedConfigWithSecrets_createsEmbeddingModel(EMBEDDING); + } + + private void testParsePersistedConfigWithSecrets_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var dimensionsSetByUser = randomBoolean(); + var maxInputTokens = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var multimodalModel = taskType == EMBEDDING && randomBoolean(); + var inputType = InputTypeTests.randomRequestType(); + var lateChunking = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + var chunkingSettings = createRandomChunkingSettings(); + + var serviceSettingsMap = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + if (taskType == EMBEDDING) { + serviceSettingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(inputType, lateChunking), + chunkingSettings.asMap(), + getSecretSettingsMap(apiKey) ); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var model = service.parsePersistedConfigWithSecrets("id", taskType, persistedConfig.config(), persistedConfig.secrets()); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertEmbeddingModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + multimodalModel, + new JinaAIEmbeddingsTaskSettings(inputType, lateChunking), + chunkingSettings, + apiKey + ); } } - public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - createRandomChunkingSettingsMap(), - getSecretSettingsMap("secret") + JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute), + JinaAIRerankTaskSettingsTests.getTaskSettingsMap(topN, returnDocuments), + getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, persistedConfig.config(), persistedConfig.secrets()); + + assertRerankModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + apiKey, + new JinaAIRerankTaskSettings(topN, returnDocuments) ); + } + } - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingsModel_textEmbedding() throws IOException { + testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TEXT_EMBEDDING); + } - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); - } + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingsModel_embedding() throws IOException { + testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(EMBEDDING); } - public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + private void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + Map chunkingSettingsMap = randomBoolean() ? Map.of() : null; + var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), + getMapOfMinimalEmbeddingSettings(modelName), Map.of(), - getSecretSettingsMap("secret") + chunkingSettingsMap, + getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfigWithSecrets("id", taskType, persistedConfig.config(), persistedConfig.secrets()); + + assertEmbeddingModelSettings( + model, + modelName, + DEFAULT_RATE_LIMIT_SETTINGS, + null, + null, + false, + null, + JinaAIEmbeddingType.FLOAT, + taskType == EMBEDDING, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + chunkingSettingsMap == null ? OLD_DEFAULT_SETTINGS : DEFAULT_SETTINGS, + apiKey ); + } + } + + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { + try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), getSecretSettingsMap(apiKey)); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, persistedConfig.config(), persistedConfig.secrets()); + + assertRerankModelSettings(model, modelName, DEFAULT_RATE_LIMIT_SETTINGS, apiKey, JinaAIRerankTaskSettings.EMPTY_SETTINGS); } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfigWithSecrets_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("oldmodel", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") + var unsupportedTaskType = randomValueOtherThanMany( + t -> service.supportedTaskTypes().contains(t), + () -> randomFrom(TaskType.values()) ); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("oldmodel", null), Map.of(), getSecretSettingsMap("secret")); var thrownException = expectThrows( ElasticsearchStatusException.class, () -> service.parsePersistedConfigWithSecrets( "id", - TaskType.SPARSE_EMBEDDING, + unsupportedTaskType, persistedConfig.config(), persistedConfig.secrets() ) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); - assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); + assertThat( + thrownException.getMessage(), + containsString(Strings.format("The [jinaai] service does not support task type [%s]", unsupportedTaskType)) + ); } } public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { + String modelName = "modelName"; + String apiKey = "secret"; var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), + getServiceSettingsMap(modelName, null), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), - getSecretSettingsMap("secret") + getSecretSettingsMap(apiKey) ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createJinaAIService()) { - var secretSettingsMap = getSecretSettingsMap("secret"); + String modelName = "modelName"; + String apiKey = "secret"; + var secretSettingsMap = getSecretSettingsMap(apiKey); secretSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - secretSettingsMap - ); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), secretSettingsMap); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - getSecretSettingsMap("secret") - ); - persistedConfig.secrets().put("extra_key", "value"); - - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() - ); + String modelName = "modelName"; + String apiKey = "secret"; + var serviceSettingsMap = getServiceSettingsMap(modelName, null); + serviceSettingsMap.put("extra_key", "value"); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, Map.of(), getSecretSettingsMap("secret")); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createJinaAIService()) { - var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT); - serviceSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, Map.of(), getSecretSettingsMap("secret")); + String modelName = "modelName"; + String apiKey = "secret"; - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(modelName, null), + new HashMap<>(Map.of("extra_key", "value")), + getSecretSettingsMap(apiKey) ); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertParsePersistedConfigWithSecretsMinimalSettings(service, persistedConfig, modelName, apiKey); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInChunkingSettings() throws IOException { try (var service = createJinaAIService()) { - var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); - taskSettingsMap.put("extra_key", "value"); + String modelName = "modelName"; + String apiKey = "secret"; var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - taskSettingsMap, - getSecretSettingsMap("secret") + getServiceSettingsMap(modelName, null), + Map.of(), + Map.of(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.NONE.toString(), "extra_key", "value"), + getSecretSettingsMap(apiKey) ); var model = service.parsePersistedConfigWithSecrets( "id", - TaskType.TEXT_EMBEDDING, + randomEmbeddingTaskType(), persistedConfig.config(), persistedConfig.secrets() ); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); - assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is(apiKey)); } } - public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { - try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of() - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + public void testParsePersistedConfig_createsEmbeddingsModel_textEmbedding() throws IOException { + testParsePersistedConfig_createsEmbeddingModel(TEXT_EMBEDDING); + } - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertNull(embeddingsModel.getSecretSettings()); - } + public void testParsePersistedConfig_createsEmbeddingsModel_embedding() throws IOException { + testParsePersistedConfig_createsEmbeddingModel(EMBEDDING); } - public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + private void testParsePersistedConfig_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of(), - createRandomChunkingSettingsMap() + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var dimensionsSetByUser = randomBoolean(); + var maxInputTokens = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var multimodalModel = taskType == EMBEDDING && randomBoolean(); + var inputType = InputTypeTests.randomRequestType(); + var lateChunking = randomBoolean(); + var chunkingSettings = createRandomChunkingSettings(); + + var serviceSettingsMap = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + if (taskType == EMBEDDING) { + serviceSettingsMap.put(MULTIMODAL_MODEL, multimodalModel); + } - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(inputType, lateChunking), + chunkingSettings.asMap(), + null + ); + + var model = service.parsePersistedConfig("id", taskType, persistedConfig.config()); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); + assertEmbeddingModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + multimodalModel, + new JinaAIEmbeddingsTaskSettings(inputType, lateChunking), + chunkingSettings, + "" + ); } } - public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of() + JinaAIRerankServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute), + JinaAIRerankTaskSettingsTests.getTaskSettingsMap(topN, returnDocuments), + null ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var model = service.parsePersistedConfig("id", TaskType.RERANK, persistedConfig.config()); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); - assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - assertNull(embeddingsModel.getSecretSettings()); + assertRerankModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + "", + new JinaAIRerankTaskSettings(topN, returnDocuments) + ); } } - public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model_old", JinaAIEmbeddingType.FLOAT), - Map.of() + var unsupportedTaskType = randomValueOtherThanMany( + t -> service.supportedTaskTypes().contains(t), + () -> randomFrom(TaskType.values()) ); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model_old", null)); var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + () -> service.parsePersistedConfig("id", unsupportedTaskType, persistedConfig.config()) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); - assertThat(thrownException.getMessage(), containsString("The [jinaai] service does not support task type [sparse_embedding]")); + assertThat( + thrownException.getMessage(), + containsString(Strings.format("The [jinaai] service does not support task type [%s]", unsupportedTaskType)) + ); } } public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - Map.of() - ); + String modelName = "modelName"; + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null)); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); - assertNull(embeddingsModel.getSecretSettings()); + assertParsePersistedConfigMinimalSettings(service, persistedConfig, modelName); } } public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createJinaAIService()) { - var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT); + String modelName = "modelName"; + var serviceSettingsMap = getServiceSettingsMap(modelName, null); serviceSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap( - serviceSettingsMap, - JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH) - ); - - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); - - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); - assertNull(embeddingsModel.getSecretSettings()); + assertParsePersistedConfigMinimalSettings(service, persistedConfig, modelName); } } public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createJinaAIService()) { - var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); - taskSettingsMap.put("extra_key", "value"); - - var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model", JinaAIEmbeddingType.FLOAT), - taskSettingsMap - ); + String modelName = "modelName"; + var taskSettingsMap = new HashMap(Map.of("extra_key", "value")); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); - assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), taskSettingsMap); - var embeddingsModel = (JinaAIEmbeddingsModel) model; - assertThat(embeddingsModel.uri().toString(), is(DEFAULT_EMBEDDING_URL)); - assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); - assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); - assertNull(embeddingsModel.getSecretSettings()); + assertParsePersistedConfigMinimalSettings(service, persistedConfig, modelName); } } - public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException { - var sender = createMockSender(); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var mockModel = getInvalidModel("model_id", "service_name"); + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInChunkingSettings() throws IOException { + try (var service = createJinaAIService()) { + String modelName = "modelName"; - try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(modelName, null), + Map.of(), + Map.of(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.NONE.toString(), "extra_key", "value"), + null ); - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - thrownException.getMessage(), - is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") - ); + var model = service.parsePersistedConfig("id", randomEmbeddingTaskType(), persistedConfig.config()); - verify(factory, times(1)).createSender(); - verify(sender, times(1)).startAsynchronously(any()); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is("")); } - - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); } public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { - testUpdateModelWithEmbeddingDetails_Successful(null); + testUpdateModelWithEmbeddingDetails_Successful(null, 128); } public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { - testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values())); + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()), 128); } - private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { + public void testUpdateModelWithEmbeddingDetails_NullDimensionsInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()), null); + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityAndDimensionsInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(null, null); + } + + private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure, Integer dimensions) + throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { @@ -745,13 +860,15 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si randomAlphaOfLength(10), RateLimitSettingsTests.createRandom(), similarityMeasure, - randomNonNegativeInt(), + dimensions, randomNonNegativeInt(), embeddingType, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, randomAlphaOfLength(10), - false + false, + randomEmbeddingTaskType(), + randomBoolean() ); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); @@ -764,7 +881,71 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si } } - public void testInfer_Embedding_UnauthorisedResponse() throws IOException { + public void testUpdateModelWithEmbeddingDetails_returnsExistingModelIfSettingsUnchanged() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var similarityMeasure = randomSimilarityMeasure(); + var dimensions = randomNonNegativeInt(); + var model = JinaAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + RateLimitSettingsTests.createRandom(), + similarityMeasure, + dimensions, + randomNonNegativeInt(), + randomFrom(JinaAIEmbeddingType.values()), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + randomAlphaOfLength(10), + false, + randomEmbeddingTaskType(), + randomBoolean() + ); + + assertThat(service.updateModelWithEmbeddingDetails(model, dimensions), sameInstance(model)); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException { + var sender = createMockSender(); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).startAsynchronously(any()); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_TextEmbedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { @@ -776,7 +957,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); - var model = JinaAIEmbeddingsModelTests.createModel(getUrl(webServer), "model", "secret"); + var model = JinaAIEmbeddingsModelTests.createTextEmbeddingModel(getUrl(webServer), "model", "secret"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -791,7 +972,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { listener ); - var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); assertThat(webServer.requests(), hasSize(1)); @@ -825,7 +1006,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { listener ); - var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); assertThat(webServer.requests(), hasSize(1)); @@ -833,94 +1014,35 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { } public void testInfer_TextEmbedding_Get_Response_Ingest() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testInfer_TextEmbedding_Get_Response(randomFrom(InputType.INGEST, InputType.INTERNAL_INGEST), "retrieval.passage"); + } - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - apiKey, - dimensions - ); - PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); - service.infer( - model, - null, - null, - null, - input, - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + public void testInfer_TextEmbedding_Get_Response_Search() throws IOException { + testInfer_TextEmbedding_Get_Response(randomFrom(InputType.SEARCH, InputType.INTERNAL_SEARCH), "retrieval.query"); + } - var result = listener.actionGet(TIMEOUT); + public void testInfer_TextEmbedding_Get_Response_clustering() throws IOException { + testInfer_TextEmbedding_Get_Response(InputType.CLUSTERING, "separation"); + } - assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + public void testInfer_TextEmbedding_Get_Response_classification() throws IOException { + testInfer_TextEmbedding_Get_Response(InputType.CLASSIFICATION, "classification"); + } - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); + public void testInfer_TextEmbedding_Get_Response_unspecified() throws IOException { + testInfer_TextEmbedding_Get_Response(InputType.UNSPECIFIED, null); + } - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "task", - "retrieval.passage", - "embedding_type", - "float", - "dimensions", - dimensions - ) - ) - ); - } + public void testInfer_TextEmbedding_Get_Response_NullInputType() throws IOException { + testInfer_TextEmbedding_Get_Response(null, null); } - public void testInfer_TextEmbedding_Get_Response_Search() throws IOException { + private void testInfer_TextEmbedding_Get_Response(InputType inputType, String expectedJinaTask) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - String responseJson = """ + var responseJson = """ { "model": "jina-clip-v2", "object": "list", @@ -940,85 +1062,21 @@ public void testInfer_TextEmbedding_Get_Response_Search() throws IOException { ] } """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - apiKey, - dimensions - ); - PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); - service.infer( - model, - null, - null, - null, - input, - false, - new HashMap<>(), - InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var result = listener.actionGet(TIMEOUT); - - assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); - - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); - - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "task", - "retrieval.query", - "embedding_type", - "float", - "dimensions", - dimensions - ) - ) - ); - } - } - - public void testInfer_TextEmbedding_Get_Response_clustering() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, - "data":[{"object":"embedding","index":0,"embedding":[0.123, -0.123]}]} - """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - String apiKey = "apiKey"; - int dimensions = 1024; String modelName = "jina-clip-v2"; + int dimensions = 1024; + String apiKey = "apiKey"; var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), modelName, JinaAIEmbeddingType.FLOAT, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, - dimensions + dimensions, + TEXT_EMBEDDING, + false ); PlainActionFuture listener = new PlainActionFuture<>(); List input = List.of("abc"); @@ -1030,12 +1088,12 @@ public void testInfer_TextEmbedding_Get_Response_clustering() throws IOException input, false, new HashMap<>(), - InputType.CLUSTERING, + inputType, InferenceAction.Request.DEFAULT_TIMEOUT, listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); @@ -1044,67 +1102,14 @@ public void testInfer_TextEmbedding_Get_Response_clustering() throws IOException assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat( - requestMap, - is(Map.of("input", input, "model", modelName, "task", "separation", "embedding_type", "float", "dimensions", dimensions)) - ); - } - } - - public void testInfer_TextEmbedding_Get_Response_NullInputType() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - apiKey, - dimensions + Map expectedRequestMap = new HashMap<>( + Map.of("input", input, "model", modelName, "embedding_type", "float", "dimensions", dimensions) ); - PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); - service.infer(model, null, null, null, input, false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); - - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); - + if (expectedJinaTask != null) { + expectedRequestMap.put("task", expectedJinaTask); + } var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "dimensions", dimensions))); + assertThat(requestMap, is(expectedRequestMap)); } } @@ -1150,7 +1155,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1232,7 +1237,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1326,7 +1331,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1406,7 +1411,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept listener ); - var result = listener.actionGet(TIMEOUT); + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); var resultAsMap = result.asMap(); assertThat( resultAsMap, @@ -1445,169 +1450,133 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept ); } - } - public void testInfer_TextEmbedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() - throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - String apiKey = "apiKey"; - int dimensions = 1024; - String modelName = "jina-clip-v2"; - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - modelName, - JinaAIEmbeddingType.FLOAT, - new JinaAIEmbeddingsTaskSettings(null, null), - apiKey, - dimensions - ); - PlainActionFuture listener = new PlainActionFuture<>(); - List input = List.of("abc"); - service.infer( - model, - null, - null, - null, - input, - false, - new HashMap<>(), - InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + public void test_TextEmbeddingModel_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + createRandomChunkingSettings(), + "secret", + TEXT_EMBEDDING + ); - var result = listener.actionGet(TIMEOUT); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); + } - assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); + public void test_TextEmbeddingModel_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + "secret", + TEXT_EMBEDDING + ); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "dimensions", dimensions))); - } + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + public void test_TextEmbeddingModel_ChunkedInfer_LateChunkingEnabled() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", - new JinaAIEmbeddingsTaskSettings(InputType.INGEST), - createRandomChunkingSettings(), - "secret" + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), + "secret", + TEXT_EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { + public void test_TextEmbeddingModel_ChunkedInfer_LateChunkingDisabled() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", - new JinaAIEmbeddingsTaskSettings(InputType.INGEST), - "secret" + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), + "secret", + TEXT_EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_LateChunkingEnabled() throws IOException { + public void test_embeddingModel_chunkedInfer_batchesCallsWhenLateChunkingEnabled() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), - "secret" + "secret", + EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, model.getTaskSettings().getLateChunking(), false); } - public void test_Embedding_ChunkedInfer_LateChunkingDisabled() throws IOException { + public void test_embeddingModel_chunkedInfer_batchesCallsWhenLateChunkingEnabled_inputContainsNonTextInput() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), "jina-clip-v2", - new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), - "secret" + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), + "secret", + EMBEDDING ); - test_Embedding_ChunkedInfer_BatchesCalls(model); + test_embedding_chunkedInfer_batchesCalls(model, false, true); } - public void test_Embedding_ChunkedInfer_noInputs() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var model = JinaAIEmbeddingsModelTests.createModel(getUrl(webServer), "jina-clip-v2", "secret"); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer( - model, - null, - List.of(), - new HashMap<>(), - InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + public void test_embeddingModel_chunkedInfer_batchesCallsWhenLateChunkingDisabled_inputContainsNonTextInput() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), + "secret", + EMBEDDING + ); - var results = listener.actionGet(TIMEOUT); - assertThat(results, empty()); - assertThat(webServer.requests(), empty()); - } + test_embedding_chunkedInfer_batchesCalls(model, false, true); } - private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { + private void test_embedding_chunkedInfer_batchesCalls( + JinaAIEmbeddingsModel model, + Boolean expectMultipleResponses, + boolean nonTextInput + ) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { - queueResponsesForChunkedInfer(model.getTaskSettings().getLateChunking()); + queueResponsesForChunkedInfer(expectMultipleResponses); PlainActionFuture> listener = new PlainActionFuture<>(); - // 2 input + // 2 inputs + String firstInput = "first_input"; + String secondInput = "second_input"; + List input; + if (nonTextInput) { + input = List.of( + new ChunkInferenceInput(firstInput), + new ChunkInferenceInput(new InferenceStringGroup(new InferenceString(IMAGE, BASE64, secondInput)), null) + ); + } else { + input = List.of(new ChunkInferenceInput(firstInput), new ChunkInferenceInput(secondInput)); + } service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + input, new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener ); - var results = listener.actionGet(TIMEOUT); + var results = listener.actionGet(TEST_REQUEST_TIMEOUT); assertThat(results, hasSize(2)); { assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset()); + assertEquals(new ChunkedInference.TextOffset(0, firstInput.length()), floatResult.chunks().getFirst().offset()); assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(EmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, @@ -1619,7 +1588,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset()); + assertEquals(new ChunkedInference.TextOffset(0, secondInput.length()), floatResult.chunks().getFirst().offset()); assertThat(floatResult.chunks().getFirst().embedding(), Matchers.instanceOf(EmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, @@ -1707,6 +1676,300 @@ private void queueResponsesForChunkedInfer(Boolean lateChunking) { } } + public void test_ChunkedInfer_noInputs() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "jina-clip-v2", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "secret", + randomFrom(TEXT_EMBEDDING, EMBEDDING) + ); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TEST_REQUEST_TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + + public void testEmbeddingInfer_returnsError_withNonJinaModel() throws IOException { + String modelName = "model_id"; + String serviceName = "service_name"; + var mockModel = getInvalidModel(modelName, serviceName); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.embeddingInfer( + mockModel, + new EmbeddingRequest(List.of(), InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "The internal model was invalid, please delete the service [%s] with id [%s] and add it again.", + serviceName, + modelName + ) + ) + ); + assertThat(thrownException.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + public void testEmbeddingInfer_returnsError_withRerankModel() throws IOException { + var model = JinaAIRerankModelTests.createModel("modelName"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.embeddingInfer( + model, + new EmbeddingRequest(List.of(), InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [jinaai] with id [id] and add it again.") + ); + assertThat(thrownException.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + public void testEmbeddingInfer_returnsError_nonMultimodalModel_withNonTextInput() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "modelName", + JinaAIEmbeddingType.FLOAT, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + 128, + EMBEDDING, + false + ); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + var inputs = List.of( + new InferenceStringGroup("first_input"), + new InferenceStringGroup(new InferenceString(IMAGE, BASE64, "second_input")) + ); + service.embeddingInfer( + model, + new EmbeddingRequest(inputs, InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(thrownException.getMessage(), is("Non-text input provided for text-only model")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + public void testEmbeddingInfer_returnsError_multipleItemsInContentObject() throws IOException { + var model = JinaAIEmbeddingsModelTests.createEmbeddingModel(getUrl(webServer), "modelName", "apiKey"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture listener = new PlainActionFuture<>(); + + var listSize = randomIntBetween(1, 10); + var inputs = new ArrayList(listSize); + for (int i = 0; i < listSize; ++i) { + inputs.add(new InferenceStringGroup("a_string")); + } + + // Add an InferenceStringGroup with multiple InferenceStrings at a random point in the input list + var indexToAdd = randomIntBetween(0, inputs.size() - 1); + var multipleInferenceStrings = new InferenceStringGroup( + List.of( + new InferenceString(TEXT, InferenceString.DataFormat.TEXT, "first_input"), + new InferenceString(IMAGE, BASE64, "second_input") + ) + ); + inputs.add(indexToAdd, multipleInferenceStrings); + service.embeddingInfer( + model, + new EmbeddingRequest(inputs, InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Field [content] must contain a single item for [jinaai] service. " + + "[content] object with multiple items found at $.input.content[%d]", + indexToAdd + ) + ) + ); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + public void testEmbeddingInfer_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createEmbeddingModel(getUrl(webServer), "model", "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.embeddingInfer( + model, + new EmbeddingRequest(List.of(), InputType.UNSPECIFIED, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testEmbeddingInfer_Ingest() throws IOException { + testEmbeddingInfer(randomFrom(InputType.INGEST, InputType.INTERNAL_INGEST), "retrieval.passage"); + } + + public void testEmbeddingInfer_Search() throws IOException { + testEmbeddingInfer(randomFrom(InputType.SEARCH, InputType.INTERNAL_SEARCH), "retrieval.query"); + } + + public void testEmbeddingInfer_clustering() throws IOException { + testEmbeddingInfer(InputType.CLUSTERING, "separation"); + } + + public void testEmbeddingInfer_classification() throws IOException { + testEmbeddingInfer(InputType.CLASSIFICATION, "classification"); + } + + public void testEmbeddingInfer_nullInputType() throws IOException { + testEmbeddingInfer(null, null); + } + + public void testEmbeddingInfer_unspecifiedInputType() throws IOException { + testEmbeddingInfer(InputType.UNSPECIFIED, null); + } + + private void testEmbeddingInfer(InputType inputType, String expectedJinaTask) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + var responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + String modelName = "jina-clip-v2"; + int dimensions = 1024; + String apiKey = "apiKey"; + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + modelName, + JinaAIEmbeddingType.FLOAT, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + apiKey, + dimensions, + EMBEDDING, + true + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + var inputs = List.of( + new InferenceStringGroup("first_input"), + new InferenceStringGroup(new InferenceString(IMAGE, BASE64, "second_input")) + ); + service.embeddingInfer( + model, + new EmbeddingRequest(inputs, inputType, Map.of()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); + + assertEquals( + GenericDenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), + result.asMap() + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer " + apiKey)); + + Map expectedRequestMap = new HashMap<>( + Map.of( + "input", + List.of(Map.of("text", "first_input"), Map.of("image", "second_input")), + "model", + modelName, + "embedding_type", + "float", + "dimensions", + dimensions + ) + ); + if (expectedJinaTask != null) { + expectedRequestMap.put("task", expectedJinaTask); + } + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap, is(expectedRequestMap)); + } + } + public void testDefaultSimilarity_BinaryEmbedding() { assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BINARY)); assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BIT)); @@ -1724,7 +1987,7 @@ public void testGetConfiguration() throws Exception { { "service": "jinaai", "name": "Jina AI", - "task_types": ["text_embedding", "rerank"], + "task_types": ["text_embedding", "rerank", "embedding"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -1733,7 +1996,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "embedding"] }, "dimensions": { "description": "The number of dimensions the resulting embeddings should have. For more information refer to https://api.jina.ai/docs#tag/embeddings/operation/create_embedding_v1_embeddings_post.", @@ -1742,7 +2005,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "embedding"] }, "embedding_type": { "description": "The type of embedding to return. One of [float, bit, binary]. bit and binary are equivalent and are encoded as bytes with signed int8 precision.", @@ -1752,7 +2015,7 @@ public void testGetConfiguration() throws Exception { "updatable": false, "default_value": "float", "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "embedding"] }, "similarity": { "description": "The similarity measure. One of [cosine, dot_product, l2_norm]. For float embeddings, the default similarity is dot_product. For bit and binary embeddings, the default similarity is l2_norm.", @@ -1761,7 +2024,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "embedding"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1770,7 +2033,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "embedding"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1779,7 +2042,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "rerank"] + "supported_task_types": ["text_embedding", "rerank", "embedding"] } } } @@ -1828,9 +2091,7 @@ private Map getRequestConfigMap( builtServiceSettings.putAll(serviceSettings); builtServiceSettings.putAll(secretSettings); - return new HashMap<>( - Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) - ); + return new HashMap<>(Map.of(SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)); } private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { @@ -1854,4 +2115,113 @@ public InferenceService createInferenceService() { protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(5500)); } + + private static void assertEmbeddingModelSettings( + Model model, + String modelName, + RateLimitSettings rateLimitSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + JinaAIEmbeddingType embeddingType, + boolean multimodalModel, + JinaAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + String apiKey + ) { + assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + assertCommonModelSettings( + embeddingsModel, + DEFAULT_EMBEDDING_URL, + modelName, + rateLimitSettings, + similarity, + dimensions, + dimensionsSetByUser, + chunkingSettings, + apiKey + ); + + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(maxInputTokens)); + assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(embeddingType)); + assertThat(embeddingsModel.getServiceSettings().isMultimodal(), is(multimodalModel)); + + assertThat(embeddingsModel.getTaskSettings(), is(taskSettings)); + } + + private static void assertRerankModelSettings( + Model model, + String modelName, + RateLimitSettings rateLimitSettings, + String apiKey, + JinaAIRerankTaskSettings taskSettings + ) { + assertThat(model, instanceOf(JinaAIRerankModel.class)); + + var rerankModel = (JinaAIRerankModel) model; + assertCommonModelSettings(rerankModel, DEFAULT_RERANK_URL, modelName, rateLimitSettings, null, null, null, null, apiKey); + + assertThat(rerankModel.getTaskSettings(), is(taskSettings)); + } + + private static void assertCommonModelSettings( + T model, + String url, + String modelName, + RateLimitSettings rateLimitSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable ChunkingSettings chunkingSettings, + String apiKey + ) { + assertThat(model.uri().toString(), is(url)); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.rateLimitServiceSettings().rateLimitSettings(), is(rateLimitSettings)); + assertThat(model.getServiceSettings().similarity(), is(similarity)); + assertThat(model.getServiceSettings().dimensions(), is(dimensions)); + assertThat(model.getServiceSettings().dimensionsSetByUser(), is(dimensionsSetByUser)); + + assertThat(model.getConfigurations().getChunkingSettings(), is(chunkingSettings)); + + assertThat(model.apiKey().toString(), is(apiKey)); + } + + private static ActionListener getModelListenerForStatusException(String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is(expectedMessage)); + }); + } + + private static void assertParsePersistedConfigWithSecretsMinimalSettings( + JinaAIService service, + Utils.PersistedConfig persistedConfig, + String modelName, + String apiKey + ) { + var model = service.parsePersistedConfigWithSecrets( + "id", + randomFrom(service.supportedTaskTypes()), + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is(apiKey)); + } + + private static void assertParsePersistedConfigMinimalSettings( + JinaAIService service, + Utils.PersistedConfig persistedConfig, + String modelName + ) { + var model = service.parsePersistedConfig("id", randomFrom(service.supportedTaskTypes()), persistedConfig.config()); + + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.apiKey().toString(), is("")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..c67d4c158fab5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/BaseJinaAIEmbeddingsServiceSettingsTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; +import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.updateEmbeddingDetails; +import static org.hamcrest.Matchers.sameInstance; + +public class BaseJinaAIEmbeddingsServiceSettingsTests extends ESTestCase { + + public void testUpdateEmbeddingDetails_returnsSameInstance_whenEmbeddingSizeAndSimilarityAreSame() { + var settings = randomBoolean() + ? JinaAIEmbeddingServiceSettingsTests.createRandomWithNoNullValues() + : JinaAITextEmbeddingServiceSettingsTests.createRandomWithNoNullValues(); + + var updatedSettings = updateEmbeddingDetails(settings, settings.dimensions(), settings.similarity()); + + assertThat(updatedSettings, sameInstance(settings)); + } + + /** + * Returns a map containing only the fields that are required by both {@link JinaAIEmbeddingServiceSettings} and + * {@link JinaAITextEmbeddingServiceSettings} + */ + public static Map getMapOfMinimalEmbeddingSettings(String modelName) { + return getMapOfCommonEmbeddingSettings(modelName, null, null, null, null, null, null); + } + + /** + * Returns a map containing all fields that are used by both {@link JinaAIEmbeddingServiceSettings} and + * {@link JinaAITextEmbeddingServiceSettings} + */ + public static Map getMapOfCommonEmbeddingSettings( + String modelName, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + @Nullable Integer requestsPerMinute + ) { + var map = JinaAIServiceSettingsTests.getServiceSettingsMap(modelName, requestsPerMinute); + if (similarity != null) { + map.put(SIMILARITY, similarity.toString()); + } + if (dimensions != null) { + map.put(DIMENSIONS, dimensions); + } + if (dimensionsSetByUser != null) { + map.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + if (maxInputTokens != null) { + map.put(MAX_INPUT_TOKENS, maxInputTokens); + } + if (embeddingType != null) { + map.put(EMBEDDING_TYPE, embeddingType.toString()); + } + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java new file mode 100644 index 0000000000000..e2bea367aa8f6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingServiceSettingsTests.java @@ -0,0 +1,391 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; +import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.PERSISTENT; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.REQUEST; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; +import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; +import static org.hamcrest.Matchers.is; + +public class JinaAIEmbeddingServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static JinaAIEmbeddingServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = randomBoolean() ? null : randomSimilarityMeasure(); + Integer dimensions = randomBoolean() ? null : randomIntBetween(32, 256); + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + + var commonSettings = JinaAIServiceSettingsTests.createRandom(); + var embeddingType = randomBoolean() ? null : randomFrom(JinaAIEmbeddingType.values()); + var dimensionsSetByUser = randomBoolean(); + var multimodalModel = randomBoolean(); + + return new JinaAIEmbeddingServiceSettings( + commonSettings, + similarityMeasure, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + public static JinaAIEmbeddingServiceSettings createRandomWithNoNullValues() { + SimilarityMeasure similarityMeasure = randomSimilarityMeasure(); + Integer dimensions = randomIntBetween(32, 256); + Integer maxInputTokens = randomIntBetween(128, 256); + + var commonSettings = JinaAIServiceSettingsTests.createRandom(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var dimensionsSetByUser = randomBoolean(); + var multimodalModel = randomBoolean(); + + return new JinaAIEmbeddingServiceSettings( + commonSettings, + similarityMeasure, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + public void testFromMap_persistentContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), PERSISTENT); + } + + public void testFromMap_requestContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), REQUEST); + } + + public void testFromMap_requestContext_nullDimensions_createsSettingsCorrectly() { + testFromMap(null, REQUEST); + } + + private static void testFromMap(Integer dimensions, ConfigurationParseContext parseContext) { + var similarity = randomSimilarityMeasure(); + var maxInputTokens = randomNonNegativeInt(); + var model = randomAlphanumericOfLength(8); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var requestsPerMinute = randomNonNegativeInt(); + var multimodalModel = randomBoolean(); + var settingsMap = getMapOfCommonEmbeddingSettings( + model, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute + ); + + settingsMap.put(MULTIMODAL_MODEL, multimodalModel); + + boolean dimensionsSetByUser; + if (parseContext == REQUEST) { + dimensionsSetByUser = dimensions != null; + } else { + dimensionsSetByUser = randomBoolean(); + settingsMap.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + + var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap(settingsMap, parseContext); + + assertThat( + serviceSettings, + is( + new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ) + ) + ); + } + + public void testFromMap_onlyRequiredFields() { + var model = "model"; + var serviceSettings = JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, model)), + randomFrom(ConfigurationParseContext.values()) + ); + + assertThat( + serviceSettings, + is( + new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(model, null), + null, + null, + null, + JinaAIEmbeddingType.FLOAT, + false, + true + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%s] received. [similarity] " + + "must be one of [cosine, dot_product, l2_norm];", + similarity + ) + ) + ); + } + + public void testFromMap_nonPositiveDimensions_throwsError() { + var dimensions = randomIntBetween(-5, 0); + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", DIMENSIONS, dimensions)), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [%s] must be a positive integer;", + dimensions, + DIMENSIONS + ) + ) + ); + } + + public void testFromMap_InvalidEmbeddingType_ThrowsError() { + var embeddingType = "invalid"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), + randomFrom(ConfigurationParseContext.values()) + ) + ); + + assertThat( + thrownException.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%s] received. [embedding_type] " + + "must be one of [binary, bit, float];", + embeddingType + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings("model", new RateLimitSettings(3)), + SimilarityMeasure.COSINE, + 5, + 10, + JinaAIEmbeddingType.FLOAT, + true, + true + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat(xContentResult, is(stripWhitespace(""" + { + "model_id":"model", + "rate_limit":{"requests_per_minute":3}, + "dimensions":5, + "embedding_type":"float", + "max_input_tokens":10, + "similarity":"cosine", + "multimodal_model":true, + "dimensions_set_by_user": true + }"""))); + } + + public void testUpdate() { + var settings = createRandom(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomIntBetween(32, 256); + + var newSettings = settings.update(similarity, dimensions); + + var expectedSettings = new JinaAIEmbeddingServiceSettings( + settings.getCommonSettings(), + similarity, + dimensions, + settings.maxInputTokens(), + settings.getEmbeddingType(), + settings.dimensionsSetByUser(), + settings.isMultimodal() + ); + + assertThat(newSettings, is(expectedSettings)); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIEmbeddingServiceSettings::new; + } + + @Override + protected JinaAIEmbeddingServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected JinaAIEmbeddingServiceSettings mutateInstance(JinaAIEmbeddingServiceSettings instance) throws IOException { + var commonSettings = instance.getCommonSettings(); + var similarity = instance.similarity(); + var dimensions = instance.dimensions(); + var maxInputTokens = instance.maxInputTokens(); + var embeddingType = instance.getEmbeddingType(); + var dimensionsSetByUser = instance.dimensionsSetByUser(); + var multimodal = instance.isMultimodal(); + switch (randomInt(6)) { + case 0 -> commonSettings = randomValueOtherThan(commonSettings, JinaAIServiceSettingsTests::createRandom); + case 1 -> similarity = randomValueOtherThan(similarity, () -> randomFrom(randomSimilarityMeasure(), null)); + case 2 -> dimensions = randomValueOtherThan(dimensions, ESTestCase::randomNonNegativeIntOrNull); + case 3 -> maxInputTokens = randomValueOtherThan(maxInputTokens, () -> randomFrom(randomIntBetween(128, 256), null)); + case 4 -> embeddingType = randomValueOtherThan(embeddingType, () -> randomFrom(JinaAIEmbeddingType.values())); + case 5 -> dimensionsSetByUser = randomValueOtherThan(dimensionsSetByUser, ESTestCase::randomBoolean); + case 6 -> multimodal = randomValueOtherThan(multimodal, ESTestCase::randomBoolean); + default -> throw new AssertionError("Illegal randomisation branch"); + } + + return new JinaAIEmbeddingServiceSettings( + commonSettings, + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodal + ); + } + + @Override + protected JinaAIEmbeddingServiceSettings mutateInstanceForVersion(JinaAIEmbeddingServiceSettings instance, TransportVersion version) { + boolean multimodalModel = instance.isMultimodal(); + boolean dimensionsSetByUser = instance.dimensionsSetByUser(); + JinaAIEmbeddingType embeddingType = instance.getEmbeddingType(); + + // default to false for multimodal if node is on a version before embedding task support was added + if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED) == false) { + multimodalModel = false; + } + // default to false for dimensionsSetByUser if node is on a version before support for setting embedding dimensions was added + if (version.supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED) == false) { + dimensionsSetByUser = false; + } + // default to null embedding type if node is on a version before embedding type was introduced + if (version.supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED) == false) { + embeddingType = null; + } + return new JinaAIEmbeddingServiceSettings( + instance.getCommonSettings(), + instance.similarity(), + instance.dimensions(), + instance.maxInputTokens(), + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public static Map getServiceSettingsMap( + String modelName, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + @Nullable Integer requestsPerMinute, + @Nullable Boolean multimodalModel + ) { + var map = getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute + ); + + if (multimodalModel != null) { + map.put(MULTIMODAL_MODEL, multimodalModel); + } + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java index dd119cfcca683..9b8d5a0cfa784 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -19,6 +20,8 @@ import java.util.Map; +import static org.elasticsearch.inference.TaskType.EMBEDDING; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.hamcrest.Matchers.is; @@ -27,13 +30,13 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase { public void testConstructor_usesDefaultUrlWhenNull() { - var model = createModel(null, randomAlphaOfLength(10), randomAlphaOfLength(10)); + var model = createTextEmbeddingModel(null, randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThat(model.uri().toString(), is("https://api.jina.ai/v1/embeddings")); } public void testConstructor_usesUrlWhenSpecified() { String url = "some_URL"; - var model = createModel(url, randomAlphaOfLength(10), randomAlphaOfLength(10)); + var model = createTextEmbeddingModel(url, randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThat(model.uri().toString(), is(url)); } @@ -42,7 +45,8 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty() { null, "modelName", new JinaAIEmbeddingsTaskSettings(randomFrom(VALID_INPUT_TYPE_VALUES), randomBoolean()), - "api_key" + "api_key", + randomEmbeddingTaskType() ); var overriddenModel = JinaAIEmbeddingsModel.of(model, Map.of()); @@ -54,7 +58,8 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { null, "modelName", new JinaAIEmbeddingsTaskSettings(randomFrom(VALID_INPUT_TYPE_VALUES), randomBoolean()), - "api_key" + "api_key", + randomEmbeddingTaskType() ); var overriddenModel = JinaAIEmbeddingsModel.of(model, null); @@ -63,7 +68,7 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEqual() { JinaAIEmbeddingsTaskSettings taskSettings = new JinaAIEmbeddingsTaskSettings(randomFrom(VALID_INPUT_TYPE_VALUES), randomBoolean()); - var model = createModel(null, "modelName", taskSettings, "api_key"); + var model = createModel(null, "modelName", taskSettings, "api_key", randomEmbeddingTaskType()); var overriddenModel = JinaAIEmbeddingsModel.of( model, @@ -75,28 +80,37 @@ public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEqual() { public void testOf_SetsInputType_FromRequestTaskSettings_IfValid_OverridingStoredTaskSettings() { String modelName = "modelName"; String apiKey = "api_key"; - var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey); + TaskType taskType = randomEmbeddingTaskType(); + var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey, taskType); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH)); - var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, true), apiKey); + var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, true), apiKey, taskType); assertThat(overriddenModel, is(expectedModel)); } public void testOf_SetsLateChunking_FromRequestTaskSettings() { String modelName = "modelName"; String apiKey = "api_key"; - var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey); + TaskType taskType = randomEmbeddingTaskType(); + var model = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, true), apiKey, taskType); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST, false)); - var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), apiKey); + var expectedModel = createModel(null, modelName, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, false), apiKey, taskType); assertThat(overriddenModel, is(expectedModel)); } /** - * Returns a model with empty task settings, service settings and chunking settings + * Returns a model with empty task settings, service settings and chunking settings, using the {@link TaskType#TEXT_EMBEDDING} task type */ - public static JinaAIEmbeddingsModel createModel(String url, String modelName, String apiKey) { - return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey); + public static JinaAIEmbeddingsModel createTextEmbeddingModel(String url, String modelName, String apiKey) { + return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, TaskType.TEXT_EMBEDDING); + } + + /** + * Returns a model with empty task settings, service settings and chunking settings, using the {@link TaskType#EMBEDDING} task type + */ + public static JinaAIEmbeddingsModel createEmbeddingModel(String url, String modelName, String apiKey) { + return createModel(url, modelName, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, apiKey, EMBEDDING); } /** @@ -106,10 +120,11 @@ public static JinaAIEmbeddingsModel createModel( @Nullable String url, String modelName, JinaAIEmbeddingsTaskSettings taskSettings, - String apiKey + String apiKey, + TaskType taskType ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false); - return createModel(url, serviceSettings, taskSettings, null, apiKey); + var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, taskType == EMBEDDING); + return createModel(url, serviceSettings, taskSettings, null, apiKey, taskType); } /** @@ -120,10 +135,11 @@ public static JinaAIEmbeddingsModel createModel( String modelName, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable ChunkingSettings chunkingSettings, - String apiKey + String apiKey, + TaskType taskType ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false); - return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey); + var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, null, null, null, false, taskType, taskType == EMBEDDING); + return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey, taskType); } /** @@ -135,10 +151,22 @@ public static JinaAIEmbeddingsModel createModel( @Nullable JinaAIEmbeddingType embeddingType, JinaAIEmbeddingsTaskSettings taskSettings, String apiKey, - @Nullable Integer dimensions + @Nullable Integer dimensions, + TaskType taskType, + boolean multimodalModel ) { - var serviceSettings = getEmbeddingServiceSettings(modelName, null, null, dimensions, null, embeddingType, dimensions != null); - return createModel(url, serviceSettings, taskSettings, null, apiKey); + var serviceSettings = getEmbeddingServiceSettings( + modelName, + null, + null, + dimensions, + null, + embeddingType, + dimensions != null, + taskType, + multimodalModel + ); + return createModel(url, serviceSettings, taskSettings, null, apiKey, taskType); } public static JinaAIEmbeddingsModel createModel( @@ -152,7 +180,9 @@ public static JinaAIEmbeddingsModel createModel( JinaAIEmbeddingsTaskSettings taskSettings, @Nullable ChunkingSettings chunkingSettings, String apiKey, - boolean dimensionsSetByUser + boolean dimensionsSetByUser, + TaskType taskType, + boolean multimodalModel ) { var serviceSettings = getEmbeddingServiceSettings( modelName, @@ -161,17 +191,20 @@ public static JinaAIEmbeddingsModel createModel( dimensions, maxInputTokens, embeddingType, - dimensionsSetByUser + dimensionsSetByUser, + taskType, + multimodalModel ); - return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey); + return createModel(url, serviceSettings, taskSettings, chunkingSettings, apiKey, taskType); } public static JinaAIEmbeddingsModel createModel( @Nullable String url, - JinaAIEmbeddingsServiceSettings serviceSettings, + BaseJinaAIEmbeddingsServiceSettings serviceSettings, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable ChunkingSettings chunkingSettings, - String apiKey + String apiKey, + TaskType taskType ) { return new JinaAIEmbeddingsModel( "id", @@ -179,26 +212,43 @@ public static JinaAIEmbeddingsModel createModel( taskSettings, chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())), - url + url, + taskType ); } - public static JinaAIEmbeddingsServiceSettings getEmbeddingServiceSettings( + public static BaseJinaAIEmbeddingsServiceSettings getEmbeddingServiceSettings( String modelName, @Nullable RateLimitSettings rateLimitSettings, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable JinaAIEmbeddingType embeddingType, - boolean dimensionsSetByUser + boolean dimensionsSetByUser, + TaskType taskType, + boolean multimodalModel ) { - return new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings(modelName, rateLimitSettings), - similarity, - dimensions, - maxInputTokens, - embeddingType, - dimensionsSetByUser - ); + if (taskType == TaskType.TEXT_EMBEDDING) { + return new JinaAITextEmbeddingServiceSettings( + new JinaAIServiceSettings(modelName, rateLimitSettings), + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser + ); + } else if (taskType == EMBEDDING) { + return new JinaAIEmbeddingServiceSettings( + new JinaAIServiceSettings(modelName, rateLimitSettings), + similarity, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser, + multimodalModel + ); + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java similarity index 59% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java index 423f631c13740..c1dec2ded9f3a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAITextEmbeddingServiceSettingsTests.java @@ -35,31 +35,33 @@ import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace; import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.PERSISTENT; +import static org.elasticsearch.xpack.inference.services.ConfigurationParseContext.REQUEST; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER; import static org.elasticsearch.xpack.inference.services.ServiceFields.EMBEDDING_TYPE; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MULTIMODAL_MODEL; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; -import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; -import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; -import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettingsTests.getMapOfCommonEmbeddingSettings; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; -public class JinaAIEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase { +public class JinaAITextEmbeddingServiceSettingsTests extends AbstractBWCWireSerializationTestCase { - private static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = TransportVersion.fromName( - "jina_ai_embedding_type_support_added" - ); - - public static JinaAIEmbeddingsServiceSettings createRandom() { - SimilarityMeasure similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - Integer dimensions = 1024; + public static JinaAITextEmbeddingServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = randomBoolean() ? null : randomSimilarityMeasure(); + Integer dimensions = randomBoolean() ? null : randomIntBetween(32, 256); Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); var commonSettings = JinaAIServiceSettingsTests.createRandom(); - var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var embeddingType = randomBoolean() ? null : randomFrom(JinaAIEmbeddingType.values()); var dimensionsSetByUser = randomBoolean(); - return new JinaAIEmbeddingsServiceSettings( + return new JinaAITextEmbeddingServiceSettings( commonSettings, similarityMeasure, dimensions, @@ -69,108 +71,67 @@ public static JinaAIEmbeddingsServiceSettings createRandom() { ); } - public void testFromMap_Request_CreatesSettingsCorrectly() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; + public static JinaAITextEmbeddingServiceSettings createRandomWithNoNullValues() { + SimilarityMeasure similarityMeasure = randomSimilarityMeasure(); + Integer dimensions = randomIntBetween(32, 256); + Integer maxInputTokens = randomIntBetween(128, 256); + + var commonSettings = JinaAIServiceSettingsTests.createRandom(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - var requestsPerMinute = 1234; - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - ServiceFields.DIMENSIONS, - dimensions, - ServiceFields.MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - EMBEDDING_TYPE, - embeddingType.toString(), - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)) - ) - ), - ConfigurationParseContext.REQUEST - ); + var dimensionsSetByUser = randomBoolean(); - assertThat( - serviceSettings, - is( - new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), - similarity, - dimensions, - maxInputTokens, - embeddingType, - true - ) - ) + return new JinaAITextEmbeddingServiceSettings( + commonSettings, + similarityMeasure, + dimensions, + maxInputTokens, + embeddingType, + dimensionsSetByUser ); } - public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() { - var url = "https://www.abc.com"; - var model = "model"; - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(URL, url, ServiceFields.MODEL_ID, model)), - ConfigurationParseContext.REQUEST - ); + public void testFromMap_persistentContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), PERSISTENT); + } - assertThat( - serviceSettings, - is( - new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings(model, null), - null, - null, - null, - JinaAIEmbeddingType.FLOAT, - false - ) - ) - ); + public void testFromMap_requestContext_createsSettingsCorrectly() { + testFromMap(randomNonNegativeInt(), REQUEST); + } + + public void testFromMap_requestContext_nullDimensions_createsSettingsCorrectly() { + testFromMap(null, REQUEST); } - public void testFromMap_Persistent_CreatesSettingsCorrectly() { - var url = "https://www.abc.com"; + private static void testFromMap(Integer dimensions, ConfigurationParseContext parseContext) { var similarity = randomSimilarityMeasure(); - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; + var maxInputTokens = randomNonNegativeInt(); + var model = randomAlphanumericOfLength(8); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - var requestsPerMinute = 1234; - var dimensionsSetByUser = randomBoolean(); - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>( - Map.of( - URL, - url, - SIMILARITY, - similarity.toString(), - DIMENSIONS, - dimensions, - MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - EMBEDDING_TYPE, - embeddingType.toString(), - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)), - ServiceFields.DIMENSIONS_SET_BY_USER, - dimensionsSetByUser - ) - ), - ConfigurationParseContext.PERSISTENT + var requestsPerMinute = randomNonNegativeInt(); + var settingsMap = getMapOfCommonEmbeddingSettings( + model, + similarity, + dimensions, + null, + maxInputTokens, + embeddingType, + requestsPerMinute ); + boolean dimensionsSetByUser; + if (parseContext == REQUEST) { + dimensionsSetByUser = dimensions != null; + } else { + dimensionsSetByUser = randomBoolean(); + settingsMap.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + } + + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap(settingsMap, parseContext); + assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -182,75 +143,45 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { ); } - public void testFromMap_WithModelId() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - var maxInputTokens = 512; + public void testFromMap_onlyRequiredFields() { var model = "model"; - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - DIMENSIONS, - dimensions, - MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model - ) - ), - ConfigurationParseContext.REQUEST + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, model)), + randomFrom(ConfigurationParseContext.values()) ); assertThat( serviceSettings, is( - new JinaAIEmbeddingsServiceSettings( + new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(model, null), - similarity, - dimensions, - maxInputTokens, + null, + null, + null, JinaAIEmbeddingType.FLOAT, - true + false ) ) ); } - public void testFromMap_WithEmbeddingType() { - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 1536; - var maxInputTokens = 512; - var model = "model"; - var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.SIMILARITY, - similarity.toString(), - DIMENSIONS, - dimensions, - MAX_INPUT_TOKENS, - maxInputTokens, - ServiceFields.MODEL_ID, - model, - EMBEDDING_TYPE, - JinaAIEmbeddingType.BIT.toString() - ) - ), - ConfigurationParseContext.REQUEST + public void testFromMap_InvalidEmbeddingType_ThrowsError() { + var embeddingType = "invalid"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAITextEmbeddingServiceSettings.fromMap( + new HashMap<>(Map.of(MODEL_ID, "model", EMBEDDING_TYPE, embeddingType)), + randomFrom(ConfigurationParseContext.values()) + ) ); assertThat( - serviceSettings, + thrownException.getMessage(), is( - new JinaAIEmbeddingsServiceSettings( - new JinaAIServiceSettings(model, null), - similarity, - dimensions, - maxInputTokens, - JinaAIEmbeddingType.BIT, - true + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%s] received. [embedding_type] " + + "must be one of [binary, bit, float];", + embeddingType ) ) ); @@ -260,7 +191,7 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { var similarity = "by_size"; var thrownException = expectThrows( ValidationException.class, - () -> JinaAIEmbeddingsServiceSettings.fromMap( + () -> JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", SIMILARITY, similarity)), ConfigurationParseContext.PERSISTENT ) @@ -279,7 +210,7 @@ public void testFromMap_nonPositiveDimensions_ThrowsError() { var dimensions = randomIntBetween(-5, 0); var thrownException = expectThrows( ValidationException.class, - () -> JinaAIEmbeddingsServiceSettings.fromMap( + () -> JinaAITextEmbeddingServiceSettings.fromMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model", DIMENSIONS, dimensions)), randomFrom(ConfigurationParseContext.values()) ) @@ -297,6 +228,15 @@ public void testFromMap_nonPositiveDimensions_ThrowsError() { ); } + public void testFromMap_doesNotRemoveMultimodalModelField() { + var model = "model"; + HashMap settingsMap = new HashMap<>(Map.of(MODEL_ID, model, MULTIMODAL_MODEL, true)); + var serviceSettings = JinaAITextEmbeddingServiceSettings.fromMap(settingsMap, randomFrom(ConfigurationParseContext.values())); + + assertThat(serviceSettings.isMultimodal(), is(false)); + assertThat(settingsMap, not(anEmptyMap())); + } + public void testToXContent_WritesAllValues() throws IOException { var modelName = randomAlphanumericOfLength(10); var requestsPerMinute = randomNonNegativeInt(); @@ -305,7 +245,7 @@ public void testToXContent_WritesAllValues() throws IOException { var maxInputTokens = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var dimensionsSetByUser = false; - var serviceSettings = new JinaAIEmbeddingsServiceSettings( + var serviceSettings = new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(modelName, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -337,7 +277,7 @@ public void testToXContentFragmentOfExposedFields_WritesAllValues() throws IOExc var maxInputTokens = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var dimensionsSetByUser = false; - var serviceSettings = new JinaAIEmbeddingsServiceSettings( + var serviceSettings = new JinaAITextEmbeddingServiceSettings( new JinaAIServiceSettings(modelName, new RateLimitSettings(requestsPerMinute)), similarity, dimensions, @@ -362,18 +302,37 @@ public void testToXContentFragmentOfExposedFields_WritesAllValues() throws IOExc }""", modelName, requestsPerMinute, dimensions, embeddingType, maxInputTokens, similarity)))); } + public void testUpdate() { + var settings = createRandom(); + var similarity = randomSimilarityMeasure(); + var dimensions = randomIntBetween(32, 256); + + var newSettings = settings.update(similarity, dimensions); + + var expectedSettings = new JinaAITextEmbeddingServiceSettings( + settings.getCommonSettings(), + similarity, + dimensions, + settings.maxInputTokens(), + settings.getEmbeddingType(), + settings.dimensionsSetByUser() + ); + + assertThat(newSettings, is(expectedSettings)); + } + @Override - protected Writeable.Reader instanceReader() { - return JinaAIEmbeddingsServiceSettings::new; + protected Writeable.Reader instanceReader() { + return JinaAITextEmbeddingServiceSettings::new; } @Override - protected JinaAIEmbeddingsServiceSettings createTestInstance() { + protected JinaAITextEmbeddingServiceSettings createTestInstance() { return createRandom(); } @Override - protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsServiceSettings instance) throws IOException { + protected JinaAITextEmbeddingServiceSettings mutateInstance(JinaAITextEmbeddingServiceSettings instance) throws IOException { var commonSettings = instance.getCommonSettings(); var similarity = instance.similarity(); var dimensions = instance.dimensions(); @@ -390,7 +349,7 @@ protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsService default -> throw new AssertionError("Illegal randomisation branch"); } - return new JinaAIEmbeddingsServiceSettings( + return new JinaAITextEmbeddingServiceSettings( commonSettings, similarity, dimensions, @@ -401,7 +360,10 @@ protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsService } @Override - protected JinaAIEmbeddingsServiceSettings mutateInstanceForVersion(JinaAIEmbeddingsServiceSettings instance, TransportVersion version) { + protected JinaAITextEmbeddingServiceSettings mutateInstanceForVersion( + JinaAITextEmbeddingServiceSettings instance, + TransportVersion version + ) { if (version.supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) { return instance; } @@ -414,7 +376,7 @@ protected JinaAIEmbeddingsServiceSettings mutateInstanceForVersion(JinaAIEmbeddi embeddingType = null; } - return new JinaAIEmbeddingsServiceSettings( + return new JinaAITextEmbeddingServiceSettings( instance.getCommonSettings(), instance.similarity(), instance.dimensions(), @@ -432,13 +394,23 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } - public static Map getServiceSettingsMap(String model, @Nullable JinaAIEmbeddingType embeddingType) { - var map = new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(model)); - - if (embeddingType != null) { - map.put(EMBEDDING_TYPE, embeddingType.toString()); - } - - return map; + public static Map getServiceSettingsMap( + String modelName, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType, + @Nullable Integer requestsPerMinute + ) { + return getMapOfCommonEmbeddingSettings( + modelName, + similarity, + dimensions, + dimensionsSetByUser, + maxInputTokens, + embeddingType, + requestsPerMinute + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java index c906e9203338d..3847f2b5b09e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntityTests.java @@ -8,7 +8,10 @@ package org.elasticsearch.xpack.inference.services.jinaai.request; import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceString; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -21,18 +24,21 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.inference.InferenceString.DataFormat.BASE64; +import static org.elasticsearch.inference.InferenceString.DataType.IMAGE; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests.getEmbeddingServiceSettings; import static org.hamcrest.CoreMatchers.is; public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase { - public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + public void testXContent_nonMultimodal_WritesAllFields_WhenTheyAreDefined() throws IOException { var modelName = "modelName"; var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var lateChunking = randomBoolean(); var dimensions = randomNonNegativeInt(); var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.INTERNAL_INGEST, createModel( null, @@ -40,7 +46,9 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException embeddingType, new JinaAIEmbeddingsTaskSettings(InputType.INGEST, lateChunking), "apiKey", - dimensions + dimensions, + randomEmbeddingTaskType(), + false ) ); @@ -64,11 +72,20 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException ); } - public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { + public void testXContent_nonMultimodal_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFieldDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.INTERNAL_INGEST, - createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, false), "apiKey", null) + createModel( + null, + "modelName", + null, + new JinaAIEmbeddingsTaskSettings(null, false), + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -79,15 +96,27 @@ public void testXContent_WritesOnlyLateChunkingField_WhenItIsTheOnlyOptionalFiel {"input":["abc"],"model":"modelName","embedding_type":"float","task":"retrieval.passage","late_chunking":false}""")); } - public void testXContent_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueButInputExceedsWordCountLimit() throws IOException { + public void testXContent_nonMultimodal_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueButInputExceedsWordCountLimit() + throws IOException { int wordCount = JinaAIEmbeddingsRequestEntity.MAX_WORD_COUNT_FOR_LATE_CHUNKING + 1; - var testInput = IntStream.range(0, wordCount / 2).mapToObj(i -> "word" + i).collect(Collectors.joining(" ")) + "."; + var testInput = new InferenceStringGroup( + IntStream.range(0, wordCount / 2).mapToObj(i -> "word" + i).collect(Collectors.joining(" ")) + "." + ); var testInputs = List.of(testInput, testInput, testInput); var entity = new JinaAIEmbeddingsRequestEntity( testInputs, InputType.INTERNAL_INGEST, - createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, true), "apiKey", null) + createModel( + null, + "modelName", + null, + new JinaAIEmbeddingsTaskSettings(null, true), + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -96,14 +125,23 @@ public void testXContent_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueB assertThat(xContentResult, is(Strings.format(""" {"input":["%s","%s","%s"],"model":"modelName","embedding_type":"float",\ - "task":"retrieval.passage","late_chunking":false}""", testInput, testInput, testInput))); + "task":"retrieval.passage","late_chunking":false}""", testInput.textValue(), testInput.textValue(), testInput.textValue()))); } - public void testXContent_WritesInputTypeField_WhenItIsDefinedOnlyInTaskSettings() throws IOException { + public void testXContent_nonMultimodal_WritesInputTypeField_WhenItIsDefinedOnlyInTaskSettings() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), null, - createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, null), "apiKey", null) + createModel( + null, + "modelName", + null, + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH, null), + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -114,11 +152,20 @@ public void testXContent_WritesInputTypeField_WhenItIsDefinedOnlyInTaskSettings( {"input":["abc"],"model":"modelName","embedding_type":"float","task":"retrieval.query"}""")); } - public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + public void testXContent_nonMultimodal_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), null, - createModel(null, "modelName", null, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null) + createModel( + null, + "modelName", + null, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -129,11 +176,20 @@ public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws I {"input":["abc"],"model":"modelName","embedding_type":"float"}""")); } - public void testXContent_EmbeddingTypesBit() throws IOException { + public void testXContent_nonMultimodal_EmbeddingTypesBit() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.CLUSTERING, - createModel(null, "model", JinaAIEmbeddingType.BIT, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null) + createModel( + null, + "model", + JinaAIEmbeddingType.BIT, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -144,11 +200,20 @@ public void testXContent_EmbeddingTypesBit() throws IOException { {"input":["abc"],"model":"model","embedding_type":"binary","task":"separation"}""")); } - public void testXContent_EmbeddingTypesBinary() throws IOException { + public void testXContent_nonMultimodal_EmbeddingTypesBinary() throws IOException { var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), InputType.SEARCH, - createModel(null, "model", JinaAIEmbeddingType.BINARY, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null) + createModel( + null, + "model", + JinaAIEmbeddingType.BINARY, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "apiKey", + null, + randomEmbeddingTaskType(), + false + ) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -159,12 +224,13 @@ public void testXContent_EmbeddingTypesBinary() throws IOException { {"input":["abc"],"model":"model","embedding_type":"binary","task":"retrieval.query"}""")); } - public void testXContent_doesNotWriteDimensions_whenDimensionsSetByUserIsFalse() throws IOException { - var serviceSettings = getEmbeddingServiceSettings("modelName", null, null, 512, null, null, false); + public void testXContent_nonMultimodal_doesNotWriteDimensions_whenDimensionsSetByUserIsFalse() throws IOException { + var taskType = randomEmbeddingTaskType(); + var serviceSettings = getEmbeddingServiceSettings("modelName", null, null, 512, null, null, false, taskType, false); var entity = new JinaAIEmbeddingsRequestEntity( - List.of("abc"), + List.of(new InferenceStringGroup("abc")), null, - createModel(null, serviceSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, "apiKey") + createModel(null, serviceSettings, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, "apiKey", taskType) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -174,4 +240,105 @@ public void testXContent_doesNotWriteDimensions_whenDimensionsSetByUserIsFalse() assertThat(xContentResult, is(""" {"input":["abc"],"model":"modelName","embedding_type":"float"}""")); } + + public void testXContent_multimodal_WritesAllFields_WhenTheyAreDefined() throws IOException { + var modelName = "modelName"; + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var lateChunking = randomBoolean(); + var dimensions = randomNonNegativeInt(); + String textInput = "text input"; + String imageInput = "image input"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(new InferenceString(IMAGE, imageInput))), + InputType.INTERNAL_INGEST, + createModel( + null, + modelName, + embeddingType, + new JinaAIEmbeddingsTaskSettings(InputType.INGEST, lateChunking), + "apiKey", + dimensions, + TaskType.EMBEDDING, + true + ) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat( + xContentResult, + is( + Strings.format( + """ + {"input":[{"text":"%s"},{"image":"%s"}],"model":"%s",\ + "embedding_type":"%s","task":"retrieval.passage","late_chunking":false,"dimensions":%d}""", + textInput, + imageInput, + modelName, + embeddingType.toRequestString(), + dimensions + ) + ) + ); + } + + public void testXContent_multimodal_WritesTrueLateChunkingField_WhenLateChunkingSetToTrueAndInputContainsOnlyTextInput() + throws IOException { + + String textInput1 = "text input 1"; + String textInput2 = "text input 2"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput1), new InferenceStringGroup(textInput2)), + InputType.INTERNAL_INGEST, + createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, true), "apiKey", null, TaskType.EMBEDDING, true) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"input":[{"text":"%s"},{"text":"%s"}],"model":"modelName","embedding_type":"float",\ + "task":"retrieval.passage","late_chunking":true}""", textInput1, textInput2))); + } + + public void testXContent_multimodal_WritesFalseLateChunkingField_WhenLateChunkingSetToTrueAndInputContainsNonTextInput() + throws IOException { + + String textInput = "text input"; + String imageInput = "image input"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(List.of(new InferenceString(IMAGE, BASE64, imageInput)))), + InputType.INTERNAL_INGEST, + createModel(null, "modelName", null, new JinaAIEmbeddingsTaskSettings(null, true), "apiKey", null, TaskType.EMBEDDING, true) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"input":[{"text":"%s"},{"image":"%s"}],"model":"modelName","embedding_type":"float",\ + "task":"retrieval.passage","late_chunking":false}""", textInput, imageInput))); + } + + public void testXContent_multimodal_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + String textInput = "text input"; + String imageInput = "image input"; + String modelName = "modelName"; + var entity = new JinaAIEmbeddingsRequestEntity( + List.of(new InferenceStringGroup(textInput), new InferenceStringGroup(new InferenceString(IMAGE, imageInput))), + null, + createModel(null, modelName, null, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "apiKey", null, TaskType.EMBEDDING, true) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"input":[{"text":"%s"},{"image":"%s"}],"model":"%s","embedding_type":"float"}""", textInput, imageInput, modelName))); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java index 4e8c79bf72dfe..bff4ea69eb3bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestTests.java @@ -9,7 +9,10 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InferenceString; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InputTypeTests; @@ -19,6 +22,8 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,7 +33,7 @@ import static org.hamcrest.Matchers.is; public class JinaAIEmbeddingsRequestTests extends ESTestCase { - public void testCreateRequest_AllOptionsDefined() throws IOException { + public void testCreateRequest_AllOptionsDefined_textEmbedding() throws IOException { var inputType = InputTypeTests.randomWithNull(); boolean lateChunking = randomBoolean(); var modelName = "modelName"; @@ -37,7 +42,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { var input = List.of("abc"); var apiKey = "api-key"; var dimensions = 512; - var request = createRequest( + var request = createTextOnlyRequest( input, inputType, JinaAIEmbeddingsModelTests.createModel( @@ -46,7 +51,9 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { embeddingType, new JinaAIEmbeddingsTaskSettings(null, lateChunking), apiKey, - dimensions + dimensions, + TaskType.TEXT_EMBEDDING, + false ) ); @@ -60,47 +67,82 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); + var expectedRequestMap = new HashMap<>( + Map.of( + "input", + input, + "model", + modelName, + "embedding_type", + embeddingType.toRequestString(), + "late_chunking", + lateChunking, + "dimensions", + dimensions + ) + ); + if (InputType.isSpecified(inputType)) { + expectedRequestMap.put("task", convertInputType(inputType)); + } + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); + } + + public void testCreateRequest_AllOptionsDefined_multimodalEmbedding() throws IOException { + var inputType = InputTypeTests.randomWithNull(); + boolean lateChunking = randomBoolean(); + var modelName = "modelName"; + var url = "url"; + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); + var input = List.of("abc"); + var apiKey = "api-key"; + var dimensions = 512; + var request = createMultimodalRequest( + input, + inputType, + JinaAIEmbeddingsModelTests.createModel( + url, + modelName, + embeddingType, + new JinaAIEmbeddingsTaskSettings(null, lateChunking), + apiKey, + dimensions, + TaskType.EMBEDDING, + true + ) + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is(url)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); + assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); + + var expectedRequestMap = new HashMap<>( + Map.of( + "input", + List.of(Map.of("image", input.getFirst())), + "model", + modelName, + "embedding_type", + embeddingType.toRequestString(), + "late_chunking", + false, + "dimensions", + dimensions + ) + ); if (InputType.isSpecified(inputType)) { - var convertedInputType = convertInputType(inputType); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "embedding_type", - embeddingType.toRequestString(), - "task", - convertedInputType, - "late_chunking", - lateChunking, - "dimensions", - dimensions - ) - ) - ); - } else { - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "embedding_type", - embeddingType.toRequestString(), - "late_chunking", - lateChunking, - "dimensions", - dimensions - ) - ) - ); + expectedRequestMap.put("task", convertInputType(inputType)); } + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); } public void testCreateRequest_TaskSettingsInputType() throws IOException { @@ -111,7 +153,7 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { List input = List.of("abc"); String apiKey = "api-key"; int dimensions = 512; - var request = createRequest( + var request = createTextOnlyRequest( input, null, JinaAIEmbeddingsModelTests.createModel( @@ -120,7 +162,9 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { embeddingType, new JinaAIEmbeddingsTaskSettings(inputType, null), apiKey, - dimensions + dimensions, + TaskType.TEXT_EMBEDDING, + false ) ); @@ -134,32 +178,15 @@ public void testCreateRequest_TaskSettingsInputType() throws IOException { assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); - var requestMap = entityAsMap(httpPost.getEntity().getContent()); + var expectedRequestMap = new HashMap<>( + Map.of("input", input, "model", modelName, "embedding_type", embeddingType.toRequestString(), "dimensions", dimensions) + ); if (InputType.isSpecified(inputType)) { - var convertedInputType = convertInputType(inputType); - assertThat( - requestMap, - is( - Map.of( - "input", - input, - "model", - modelName, - "embedding_type", - embeddingType.toRequestString(), - "task", - convertedInputType, - "dimensions", - dimensions - ) - ) - ); - } else { - assertThat( - requestMap, - is(Map.of("input", input, "model", modelName, "embedding_type", embeddingType.toRequestString(), "dimensions", dimensions)) - ); + expectedRequestMap.put("task", convertInputType(inputType)); } + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); } public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOException { @@ -169,10 +196,16 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti var url = "url"; List input = List.of("abc"); String apiKey = "api-key"; - var request = createRequest( + var request = createTextOnlyRequest( input, requestInputType, - JinaAIEmbeddingsModelTests.createModel(url, modelName, new JinaAIEmbeddingsTaskSettings(taskSettingsInputType, null), apiKey) + JinaAIEmbeddingsModelTests.createModel( + url, + modelName, + new JinaAIEmbeddingsTaskSettings(taskSettingsInputType, null), + apiKey, + TaskType.TEXT_EMBEDDING + ) ); var httpRequest = request.createHttpRequest(); @@ -185,19 +218,36 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey)); assertThat(httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); - var requestMap = entityAsMap(httpPost.getEntity().getContent()); + var expectedRequestMap = new HashMap<>(Map.of("input", input, "model", modelName, "embedding_type", "float")); if (InputType.isSpecified(requestInputType)) { - var convertedInputType = convertInputType(requestInputType); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "task", convertedInputType))); + expectedRequestMap.put("task", convertInputType(requestInputType)); } else if (InputType.isSpecified(taskSettingsInputType)) { - var convertedInputType = convertInputType(taskSettingsInputType); - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float", "task", convertedInputType))); - } else { - assertThat(requestMap, is(Map.of("input", input, "model", modelName, "embedding_type", "float"))); + expectedRequestMap.put("task", convertInputType(taskSettingsInputType)); + } + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(expectedRequestMap)); + } + + public static JinaAIEmbeddingsRequest createMultimodalRequest(List inputs, InputType inputType, JinaAIEmbeddingsModel model) { + boolean isTextInput = false; + List convertedInput = new ArrayList<>(); + for (String input : inputs) { + InferenceString inferenceString; + if (isTextInput) { + inferenceString = new InferenceString(InferenceString.DataType.TEXT, InferenceString.DataFormat.TEXT, input); + } else { + inferenceString = new InferenceString(InferenceString.DataType.IMAGE, InferenceString.DataFormat.BASE64, input); + } + isTextInput = isTextInput == false; + var inferenceStringGroup = new InferenceStringGroup(inferenceString); + convertedInput.add(inferenceStringGroup); } + return new JinaAIEmbeddingsRequest(convertedInput, inputType, model); } - public static JinaAIEmbeddingsRequest createRequest(List input, InputType inputType, JinaAIEmbeddingsModel model) { - return new JinaAIEmbeddingsRequest(input, inputType, model); + public static JinaAIEmbeddingsRequest createTextOnlyRequest(List inputs, InputType inputType, JinaAIEmbeddingsModel model) { + List convertedInput = inputs.stream().map(InferenceStringGroup::new).toList(); + return new JinaAIEmbeddingsRequest(convertedInput, inputType, model); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java index a94dabb1bab63..982be73f42046 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import java.io.IOException; -import java.util.HashMap; import java.util.Map; import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; @@ -70,7 +69,11 @@ protected JinaAIRerankServiceSettings mutateInstanceForVersion(JinaAIRerankServi return instance; } - public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { - return new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(model)); + public static Map getServiceSettingsMap(String model) { + return getServiceSettingsMap(model, null); + } + + public static Map getServiceSettingsMap(String model, @Nullable Integer requestsPerMinute) { + return JinaAIServiceSettingsTests.getServiceSettingsMap(model, requestsPerMinute); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java index 521a7e861371e..f33c18c5a39d7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java @@ -119,10 +119,6 @@ protected JinaAIRerankTaskSettings mutateInstance(JinaAIRerankTaskSettings insta } } - public static Map getTaskSettingsMapEmpty() { - return new HashMap<>(); - } - public static Map getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, @Nullable Boolean returnDocuments) { var map = new HashMap(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java index 5dab78b37e867..86270cf1c18b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java @@ -10,27 +10,46 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.GenericDenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.request.JinaAIEmbeddingsRequestTests; +import org.hamcrest.Matchers; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; +import static org.elasticsearch.xpack.inference.TaskTypeTests.randomEmbeddingTaskType; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.BINARY; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.BIT; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.FLOAT; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; import static org.mockito.Mockito.mock; public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase { - public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + public void testFromResponse_CreatesResultsForASingleItem_textEmbeddingTask() throws IOException { + testFromResponse_singleItem(TaskType.TEXT_EMBEDDING); + } + + public void testFromResponse_CreatesResultsForASingleItem_embeddingTask() throws IOException { + testFromResponse_singleItem(TaskType.EMBEDDING); + } + + private static void testFromResponse_singleItem(TaskType taskType) throws IOException { String responseJson = """ { "object": "list", @@ -53,22 +72,30 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { """; InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") + JinaAIEmbeddingsModelTests.createModel(null, "modelName", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", taskType) ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); + assertResultsType(taskType, JinaAIEmbeddingType.FLOAT, parsedResults); assertThat( - ((DenseEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + ((EmbeddingFloatResults) parsedResults).embeddings(), + Matchers.is(List.of(new EmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } - public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + public void testFromResponse_CreatesResultsForMultipleItems_textEmbeddingTask() throws IOException { + testFromResponse_multipleItems(TaskType.TEXT_EMBEDDING); + } + + public void testFromResponse_CreatesResultsForMultipleItems_embeddingTask() throws IOException { + testFromResponse_multipleItems(TaskType.EMBEDDING); + } + + private static void testFromResponse_multipleItems(TaskType taskType) throws IOException { String responseJson = """ { "object": "list", @@ -99,21 +126,21 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException """; InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") + JinaAIEmbeddingsModelTests.createModel(null, "modelName", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", taskType) ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); + assertResultsType(taskType, JinaAIEmbeddingType.FLOAT, parsedResults); assertThat( - ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + ((EmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new EmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new EmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -143,14 +170,7 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { var thrownException = expectThrows( IllegalStateException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat(thrownException.getMessage(), is("Failed to find required field [data] in JinaAI embeddings response")); @@ -180,14 +200,7 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { var thrownException = expectThrows( ParsingException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat( @@ -220,14 +233,7 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { var thrownException = expectThrows( IllegalStateException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat(thrownException.getMessage(), is("Failed to find required field [embedding] in JinaAI embeddings response")); @@ -256,14 +262,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { var thrownException = expectThrows( ParsingException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat( @@ -272,7 +271,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { ); } - public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOException { + public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { String responseJson = """ { "object": "list", @@ -281,11 +280,7 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep "object": "embedding", "index": 0, "embedding": [ - -55, - 74, - 101, - 67, - 83 + {} ] } ], @@ -297,29 +292,18 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep } """; - InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel( - null, - "modelName", - JinaAIEmbeddingType.BINARY, - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - "secret", - null - ) - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + var thrownException = expectThrows( + ParsingException.class, + () -> callFromResponse(responseJson, randomFrom(JinaAIEmbeddingType.values())) ); assertThat( - ((DenseEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") ); } - public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOException { + public void testFromResponse_withBitEmbeddingType_FailsWhenEmbeddingValueIsLargerThanByte() { String responseJson = """ { "object": "list", @@ -328,11 +312,7 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio "object": "embedding", "index": 0, "embedding": [ - -55, - 74, - 101, - 67, - 83 + -1024 ] } ], @@ -344,29 +324,41 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio } """; - InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + var thrownException = expectThrows(IllegalArgumentException.class, () -> callFromResponse(responseJson, randomFrom(BIT, BINARY))); + + assertThat(thrownException.getMessage(), is("Value [-1024] is out of range for a byte")); + } + + private static void callFromResponse(String responseJson, JinaAIEmbeddingType embeddingType) throws IOException { + JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), JinaAIEmbeddingsModelTests.createModel( null, "modelName", - JinaAIEmbeddingType.BIT, + embeddingType, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", - null + null, + randomEmbeddingTaskType(), + false ) ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + } - assertThat( - ((DenseEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) - ); + public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOException { + fromResponse_withNonFloatEmbeddingType(BINARY); } - public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { + public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOException { + fromResponse_withNonFloatEmbeddingType(BIT); + } + + private static void fromResponse_withNonFloatEmbeddingType(JinaAIEmbeddingType embeddingType) throws IOException { + assertThat(embeddingType, oneOf(BIT, BINARY)); String responseJson = """ { "object": "list", @@ -375,7 +367,11 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { "object": "embedding", "index": 0, "embedding": [ - {} + -55, + 74, + 101, + 67, + 83 ] } ], @@ -387,21 +383,29 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { } """; - var thrownException = expectThrows( - ParsingException.class, - () -> JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( - List.of("abc"), - InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") - ), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) + var taskType = randomEmbeddingTaskType(); + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( + List.of("abc"), + InputTypeTests.randomWithNull(), + JinaAIEmbeddingsModelTests.createModel( + null, + "modelName", + embeddingType, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "secret", + null, + taskType, + false + ) + ), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + assertResultsType(taskType, embeddingType, parsedResults); assertThat( - thrownException.getMessage(), - is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") + ((EmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new EmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -446,24 +450,42 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( - JinaAIEmbeddingsRequestTests.createRequest( + TaskType taskType = randomEmbeddingTaskType(); + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createTextOnlyRequest( List.of("abc"), InputTypeTests.randomWithNull(), - JinaAIEmbeddingsModelTests.createModel(null, "modelName", "secret") + JinaAIEmbeddingsModelTests.createModel(null, "modelName", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "secret", taskType) ), new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); + assertResultsType(taskType, FLOAT, parsedResults); assertThat( - parsedResults.embeddings(), + ((EmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new EmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new EmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new EmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); } + + private static void assertResultsType(TaskType taskType, JinaAIEmbeddingType embeddingType, InferenceServiceResults parsedResults) { + if (taskType.equals(TaskType.TEXT_EMBEDDING)) { + switch (embeddingType) { + case FLOAT -> assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); + case BIT, BINARY -> assertThat(parsedResults, instanceOf(DenseEmbeddingBitResults.class)); + } + } else if (taskType.equals(TaskType.EMBEDDING)) { + switch (embeddingType) { + case FLOAT -> assertThat(parsedResults, instanceOf(GenericDenseEmbeddingFloatResults.class)); + case BIT, BINARY -> assertThat(parsedResults, instanceOf(GenericDenseEmbeddingBitResults.class)); + } + } else { + throw new IllegalArgumentException("Invalid taskType: " + taskType); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java index 7431ca06b8264..06c90f3d73575 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleEmbeddingServiceIntegrationValidatorTests.java @@ -13,8 +13,10 @@ import org.elasticsearch.inference.EmbeddingRequest; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InferenceStringGroup; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -22,6 +24,7 @@ import org.mockito.Mock; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.inference.services.validation.SimpleEmbeddingServiceIntegrationValidator.TEST_IMAGE_BASE64_INPUT; import static org.elasticsearch.xpack.inference.services.validation.SimpleEmbeddingServiceIntegrationValidator.TEST_TEXT_INPUT; @@ -37,10 +40,6 @@ public class SimpleEmbeddingServiceIntegrationValidatorTests extends ESTestCase { - private static final EmbeddingRequest EXPECTED_REQUEST = new EmbeddingRequest( - List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT), - InputType.INTERNAL_INGEST - ); private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @Mock @@ -48,6 +47,8 @@ public class SimpleEmbeddingServiceIntegrationValidatorTests extends ESTestCase @Mock private Model mockModel; @Mock + private ServiceSettings mockServiceSettings; + @Mock private ActionListener mockActionListener; @Mock private InferenceServiceResults mockInferenceServiceResults; @@ -61,11 +62,14 @@ public void setup() { underTest = new SimpleEmbeddingServiceIntegrationValidator(); when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); + when(mockModel.getServiceSettings()).thenReturn(mockServiceSettings); + when(mockServiceSettings.isMultimodal()).thenReturn(randomBoolean()); } public void testValidate_ServiceThrowsException() { + var expectedRequest = getExpectedRequest(); doThrow(ElasticsearchStatusException.class).when(mockInferenceService) - .embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + .embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); assertThrows( ElasticsearchStatusException.class, @@ -114,11 +118,12 @@ public void testValidate_CallsListenerOnFailure_WhenServiceThrowsException() { } private void mockSuccessfulCallToService(InferenceServiceResults result) { + var expectedRequest = getExpectedRequest(); doAnswer(ans -> { ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(result); return null; - }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } @@ -128,17 +133,30 @@ private void mockNullResponseFromService() { } private void mockFailureResponseFromService(Exception exception) { + var expectedRequest = getExpectedRequest(); doAnswer(ans -> { ActionListener responseListener = ans.getArgument(3); responseListener.onFailure(exception); return null; - }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + }).when(mockInferenceService).embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyCallToService() { - verify(mockInferenceService).embeddingInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); + var expectedRequest = getExpectedRequest(); + verify(mockModel).getServiceSettings(); + verify(mockInferenceService).embeddingInfer(eq(mockModel), eq(expectedRequest), eq(TIMEOUT), any()); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); } + + private EmbeddingRequest getExpectedRequest() { + List inputs; + if (mockServiceSettings.isMultimodal()) { + inputs = List.of(TEST_TEXT_INPUT, TEST_IMAGE_BASE64_INPUT); + } else { + inputs = List.of(TEST_TEXT_INPUT); + } + return new EmbeddingRequest(inputs, InputType.INTERNAL_INGEST, Map.of()); + } }