diff --git a/docs/changelog/140323.yaml b/docs/changelog/140323.yaml
new file mode 100644
index 0000000000000..1ba19da787a22
--- /dev/null
+++ b/docs/changelog/140323.yaml
@@ -0,0 +1,5 @@
+pr: 140323
+summary: "[Inference API] Add support for embedding task to JinaAI service"
+area: Inference
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java
index 29bf2727562f5..637576ef29870 100644
--- a/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java
+++ b/server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java
@@ -9,6 +9,7 @@
package org.elasticsearch.inference;
+import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@@ -26,9 +27,11 @@
import java.io.IOException;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
import static java.util.Collections.singletonList;
+import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -62,18 +65,25 @@
* OR
*
* "input": ["first text input", "second text input"]
- * @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for
- * @param inputType The {@link InputType} of the request
+ *
+ * @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for
+ * @param inputType The {@link InputType} of the request
+ * @param taskSettings The map of task settings specific to this request
*/
-public record EmbeddingRequest(List inputs, InputType inputType) implements Writeable, ToXContentFragment {
+public record EmbeddingRequest(List inputs, InputType inputType, Map taskSettings)
+ implements
+ Writeable,
+ ToXContentFragment {
- private static final String INPUT_FIELD = "input";
+ public static final TransportVersion JINA_AI_EMBEDDING_TASK_ADDED = TransportVersion.fromName("jina_ai_embedding_task_added");
+
+ public static final String INPUT_FIELD = "input";
private static final String INPUT_TYPE_FIELD = "input_type";
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
EmbeddingRequest.class.getSimpleName(),
- args -> new EmbeddingRequest((List) args[0], (InputType) args[1])
+ args -> new EmbeddingRequest((List) args[0], (InputType) args[1], (Map) args[2])
);
static {
@@ -89,31 +99,48 @@ public record EmbeddingRequest(List inputs, InputType inpu
new ParseField(INPUT_TYPE_FIELD),
ObjectParser.ValueType.STRING
);
+ PARSER.declareField(
+ optionalConstructorArg(),
+ (parser, context) -> parser.mapOrdered(),
+ new ParseField(TASK_SETTINGS),
+ ObjectParser.ValueType.OBJECT
+ );
}
public static EmbeddingRequest of(List contents) {
- return new EmbeddingRequest(contents, null);
+ return new EmbeddingRequest(contents, null, null);
}
- public EmbeddingRequest(List inputs, @Nullable InputType inputType) {
+ public EmbeddingRequest(List inputs, @Nullable InputType inputType, @Nullable Map taskSettings) {
this.inputs = inputs;
this.inputType = Objects.requireNonNullElse(inputType, InputType.UNSPECIFIED);
+ this.taskSettings = Objects.requireNonNullElse(taskSettings, Map.of());
}
public EmbeddingRequest(StreamInput in) throws IOException {
- this(in.readCollectionAsImmutableList(InferenceStringGroup::new), in.readEnum(InputType.class));
+ this(
+ in.readCollectionAsImmutableList(InferenceStringGroup::new),
+ in.readEnum(InputType.class),
+ in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED) ? in.readGenericMap() : Map.of()
+ );
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(inputs);
out.writeEnum(inputType);
+ if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) {
+ out.writeGenericMap(taskSettings);
+ }
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(INPUT_FIELD, inputs);
builder.field(INPUT_TYPE_FIELD, inputType);
+ if (taskSettings.isEmpty() == false) {
+ builder.field(TASK_SETTINGS, taskSettings);
+ }
return builder;
}
diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java
index 986503d059120..9db304e58a987 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceStringGroup.java
@@ -22,6 +22,7 @@
import java.io.IOException;
import java.util.List;
+import java.util.Objects;
import static java.util.Collections.singletonList;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
@@ -37,20 +38,31 @@
* ]
* }
*
- * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector
*/
-public record InferenceStringGroup(List inferenceStrings) implements Writeable, ToXContentObject {
- private static final String CONTENT_FIELD = "content";
+public final class InferenceStringGroup implements Writeable, ToXContentObject {
+ public static final String CONTENT_FIELD = "content";
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
InferenceStringGroup.class.getSimpleName(),
args -> new InferenceStringGroup((List) args[0])
);
+
static {
PARSER.declareObjectArray(constructorArg(), InferenceString.PARSER::apply, new ParseField(CONTENT_FIELD));
}
+ private final List inferenceStrings;
+ private final boolean containsNonTextEntry;
+
+ /**
+ * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector
+ */
+ public InferenceStringGroup(List inferenceStrings) {
+ this.inferenceStrings = inferenceStrings;
+ containsNonTextEntry = inferenceStrings.stream().anyMatch(s -> s.isText() == false);
+ }
+
public InferenceStringGroup(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(InferenceString::new));
}
@@ -64,6 +76,18 @@ public InferenceStringGroup(String input) {
this(singletonList(new InferenceString(DataType.TEXT, input)));
}
+ public List inferenceStrings() {
+ return inferenceStrings;
+ }
+
+ public boolean containsNonTextEntry() {
+ return containsNonTextEntry;
+ }
+
+ public boolean containsMultipleInferenceStrings() {
+ return inferenceStrings.size() > 1;
+ }
+
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(inferenceStrings);
@@ -81,7 +105,7 @@ public static InferenceStringGroup parse(XContentParser parser) throws IOExcepti
var token = parser.currentToken();
if (token == XContentParser.Token.VALUE_STRING) {
// Create content object from String
- return new InferenceStringGroup(singletonList(new InferenceString(DataType.TEXT, parser.text())));
+ return new InferenceStringGroup(parser.text());
} else if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) {
// Create content object from InferenceString(s)
return InferenceStringGroup.PARSER.apply(parser, null);
@@ -134,4 +158,51 @@ public static List toInferenceStringList(List toStringList(List inferenceStringGroups) {
return InferenceString.toStringList(toInferenceStringList(inferenceStringGroups));
}
+
+ /**
+ * Method used to determine if a list of {@link InferenceStringGroup} contains any {@link InferenceString} that represent a non-text
+ * value
+ *
+ * @param inferenceStringGroups the list of {@link InferenceStringGroup} to check
+ * @return true if the input list contains any non-text values, false otherwise
+ */
+ public static boolean containsNonTextEntry(List inferenceStringGroups) {
+ return inferenceStringGroups.stream().anyMatch(InferenceStringGroup::containsNonTextEntry);
+ }
+
+ /**
+ * Method used to determine if a list of {@link InferenceStringGroup} contains any with more than one {@link InferenceString} in them
+ *
+ * @param inferenceStringGroups the list of {@link InferenceStringGroup} to check
+ * @return the index of the first {@link InferenceStringGroup} found to contain more than one {@link InferenceString}, or null if no
+ * elements in the list contain more than one {@link InferenceString}
+ */
+ public static Integer indexContainingMultipleInferenceStrings(List inferenceStringGroups) {
+ for (int i = 0; i < inferenceStringGroups.size(); i++) {
+ InferenceStringGroup inferenceStringGroup = inferenceStringGroups.get(i);
+ if (inferenceStringGroup.containsMultipleInferenceStrings()) {
+ return i;
+ }
+ }
+ return null;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) return true;
+ if (obj == null || obj.getClass() != this.getClass()) return false;
+ var that = (InferenceStringGroup) obj;
+ return Objects.equals(this.inferenceStrings, that.inferenceStrings);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(inferenceStrings);
+ }
+
+ @Override
+ public String toString() {
+ return "InferenceStringGroup[" + "inferenceStrings=" + inferenceStrings + ']';
+ }
+
}
diff --git a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java
index f97239f8df8df..81348f65e4847 100644
--- a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java
+++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java
@@ -55,6 +55,10 @@ default DenseVectorFieldMapper.ElementType elementType() {
return null;
}
+ default boolean isMultimodal() {
+ return false;
+ }
+
/**
* The model to use in the inference endpoint (e.g. text-embedding-ada-002). Sometimes the model is not defined in the service
* settings. This can happen for external providers (e.g. hugging face, azure ai studio) where the provider requires that the model
diff --git a/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv b/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv
new file mode 100644
index 0000000000000..9f242502d33af
--- /dev/null
+++ b/server/src/main/resources/transport/definitions/referable/jina_ai_embedding_task_added.csv
@@ -0,0 +1 @@
+9259000
diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv
index 77cff68200459..008c2f07ab390 100644
--- a/server/src/main/resources/transport/upper_bounds/9.4.csv
+++ b/server/src/main/resources/transport/upper_bounds/9.4.csv
@@ -1 +1 @@
-esql_vsr_converters_used,9258000
+jina_ai_embedding_task_added,9259000
diff --git a/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java b/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java
index cbf78186ee142..1b120d0d19b21 100644
--- a/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java
+++ b/server/src/test/java/org/elasticsearch/inference/EmbeddingRequestTests.java
@@ -21,7 +21,10 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED;
+import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.is;
public class EmbeddingRequestTests extends AbstractBWCSerializationTestCase {
@@ -40,6 +43,7 @@ public void testParser_withSingleString() throws IOException {
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.SEARCH));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -60,6 +64,7 @@ public void testParser_withSingleContentObject() throws IOException {
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.SEARCH));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -78,6 +83,7 @@ public void testParser_withStringArray() throws IOException {
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.SEARCH));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -106,6 +112,7 @@ public void testParser_withSingleContentObjectWithMultipleEntries() throws IOExc
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.SEARCH));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -140,6 +147,7 @@ public void testParser_withMultipleContentObjects() throws IOException {
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.SEARCH));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -173,6 +181,7 @@ public void testParser_withUnspecifiedFormats_usesDefaults() throws IOException
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.SEARCH));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -189,6 +198,46 @@ public void testParser_withNoInputType() throws IOException {
);
assertThat(request.inputs(), is(expectedInputs));
assertThat(request.inputType(), is(InputType.UNSPECIFIED));
+ assertThat(request.taskSettings(), anEmptyMap());
+ }
+ }
+
+ public void testParser_withTaskSettings() throws IOException {
+ var requestJson = """
+ {
+ "input": "some text input",
+ "task_settings": {
+ "field_one": "value_one",
+ "field_two": 123
+ }
+ }
+ """;
+ try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) {
+ var request = EmbeddingRequest.PARSER.apply(parser, null);
+ var expectedInputs = List.of(
+ new InferenceStringGroup(List.of(new InferenceString(DataType.TEXT, DataFormat.TEXT, "some text input")))
+ );
+ assertThat(request.inputs(), is(expectedInputs));
+ assertThat(request.inputType(), is(InputType.UNSPECIFIED));
+ assertThat(request.taskSettings(), is(Map.of("field_one", "value_one", "field_two", 123)));
+ }
+ }
+
+ public void testParser_withEmptyTaskSettings() throws IOException {
+ var requestJson = """
+ {
+ "input": "some text input",
+ "task_settings": {}
+ }
+ """;
+ try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) {
+ var request = EmbeddingRequest.PARSER.apply(parser, null);
+ var expectedInputs = List.of(
+ new InferenceStringGroup(List.of(new InferenceString(DataType.TEXT, DataFormat.TEXT, "some text input")))
+ );
+ assertThat(request.inputs(), is(expectedInputs));
+ assertThat(request.inputType(), is(InputType.UNSPECIFIED));
+ assertThat(request.taskSettings(), anEmptyMap());
}
}
@@ -203,7 +252,11 @@ protected EmbeddingRequest createTestInstance() {
}
public static EmbeddingRequest createRandom() {
- return new EmbeddingRequest(randomEmbeddingContents(), randomFrom(InputType.values()));
+ return new EmbeddingRequest(
+ randomEmbeddingContents(),
+ randomFrom(InputType.values()),
+ Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8))
+ );
}
private static List randomEmbeddingContents() {
@@ -216,21 +269,27 @@ private static List randomEmbeddingContents() {
@Override
protected EmbeddingRequest mutateInstance(EmbeddingRequest instance) throws IOException {
- if (randomBoolean()) {
- var embeddingContents = instance.inputs();
- return new EmbeddingRequest(
- randomValueOtherThan(embeddingContents, EmbeddingRequestTests::randomEmbeddingContents),
- instance.inputType()
+ var embeddingContents = instance.inputs();
+ var inputType = instance.inputType();
+ var taskSettings = instance.taskSettings();
+ switch (randomInt(2)) {
+ case 0 -> embeddingContents = randomValueOtherThan(embeddingContents, EmbeddingRequestTests::randomEmbeddingContents);
+ case 1 -> inputType = randomValueOtherThan(inputType, () -> randomFrom(InputType.values()));
+ case 2 -> taskSettings = randomValueOtherThan(
+ taskSettings,
+ () -> Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8))
);
- } else {
- InputType inputType = instance.inputType();
- return new EmbeddingRequest(instance.inputs(), randomValueOtherThan(inputType, () -> randomFrom(InputType.values())));
}
+ return new EmbeddingRequest(embeddingContents, inputType, taskSettings);
}
@Override
protected EmbeddingRequest mutateInstanceForVersion(EmbeddingRequest instance, TransportVersion version) {
- return instance;
+ if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED)) {
+ return instance;
+ } else {
+ return new EmbeddingRequest(instance.inputs(), instance.inputType(), Map.of());
+ }
}
@Override
diff --git a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java
index ed1f7db68b497..c873a21d936a1 100644
--- a/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java
+++ b/server/src/test/java/org/elasticsearch/inference/InferenceStringGroupTests.java
@@ -14,29 +14,63 @@
import org.elasticsearch.inference.InferenceString.DataFormat;
import org.elasticsearch.inference.InferenceString.DataType;
import org.elasticsearch.test.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.json.JsonXContent;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import static org.elasticsearch.inference.InferenceString.DataType.TEXT;
+import static org.elasticsearch.inference.InferenceStringGroup.containsNonTextEntry;
+import static org.elasticsearch.inference.InferenceStringGroup.indexContainingMultipleInferenceStrings;
import static org.elasticsearch.inference.InferenceStringGroup.toInferenceStringList;
import static org.elasticsearch.inference.InferenceStringGroup.toStringList;
import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
public class InferenceStringGroupTests extends AbstractBWCSerializationTestCase {
public void testStringConstructor() {
String stringValue = "a string";
var input = new InferenceStringGroup(stringValue);
- assertThat(input.inferenceStrings(), contains(new InferenceString(DataType.TEXT, DataFormat.TEXT, stringValue)));
+ assertThat(input.inferenceStrings(), contains(new InferenceString(TEXT, DataFormat.TEXT, stringValue)));
+ assertThat(input.containsNonTextEntry(), is(false));
+ assertThat(input.containsMultipleInferenceStrings(), is(false));
}
- public void testSingleArgumentConstructor() {
+ public void testSingleInferenceStringConstructor() {
InferenceString inferenceString = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string");
var input = new InferenceStringGroup(inferenceString);
assertThat(input.inferenceStrings(), contains(inferenceString));
+ assertThat(input.containsNonTextEntry(), is(true));
+ assertThat(input.containsMultipleInferenceStrings(), is(false));
+ }
+
+ public void testInferenceStringListConstructor() {
+ InferenceString inferenceString1 = new InferenceString(DataType.IMAGE, DataFormat.BASE64, "a string");
+ InferenceString inferenceString2 = new InferenceString(TEXT, DataFormat.TEXT, "a string");
+ var input = new InferenceStringGroup(List.of(inferenceString1, inferenceString2));
+ assertThat(input.inferenceStrings(), contains(inferenceString1, inferenceString2));
+ assertThat(input.containsNonTextEntry(), is(true));
+ assertThat(input.containsMultipleInferenceStrings(), is(true));
+ }
+
+ public void testParser_withEmptyContentObject_throws() throws IOException {
+ var requestJson = """
+ {
+ "content": {}
+ }
+ """;
+ try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) {
+ // Need to call nextToken() so that the parser is at the correct element
+ parser.nextToken();
+ var exception = expectThrows(XContentParseException.class, () -> InferenceStringGroup.PARSER.apply(parser, null));
+ assertThat(exception.getMessage(), containsString("[InferenceStringGroup] failed to parse field [content]"));
+ }
}
public void testValue_withMoreThanOneElement_throws() {
@@ -77,6 +111,46 @@ public void testToStringList_withMoreThanOneElement_throws() {
assertThat(expectedException.getMessage(), is("Multiple-input InferenceStringGroup passed to InferenceStringGroup.toStringList"));
}
+ public void testContainsNonTextEntry_withOnlyTextInputs() {
+ var inputs = List.of(new InferenceStringGroup("string1"), new InferenceStringGroup("string2"));
+ assertThat(containsNonTextEntry(inputs), is(false));
+ }
+
+ public void testContainsNonTextEntry_withNonTextInput() {
+ DataType nonTextDataType = randomValueOtherThan(TEXT, () -> randomFrom(DataType.values()));
+ var inputs = List.of(
+ new InferenceStringGroup("string1"),
+ new InferenceStringGroup(new InferenceString(nonTextDataType, "non text"))
+ );
+ assertThat(containsNonTextEntry(inputs), is(true));
+ }
+
+ public void testIndexContainingMultipleInferenceStrings_withSingleInferenceString() {
+ var inputs = getInputsList();
+ assertThat(indexContainingMultipleInferenceStrings(inputs), nullValue());
+ }
+
+ public void testIndexContainingMultipleInferenceStrings_withMultipleInferenceStrings() {
+ var inputs = getInputsList();
+
+ // Add an InferenceStringGroup with multiple InferenceStrings at a random point in the input list
+ var indexToAdd = randomIntBetween(0, inputs.size() - 1);
+ var multipleInferenceStrings = new InferenceStringGroup(
+ List.of(new InferenceString(TEXT, "a_string"), new InferenceString(TEXT, "a_string"))
+ );
+ inputs.add(indexToAdd, multipleInferenceStrings);
+ assertThat(indexContainingMultipleInferenceStrings(inputs), is(indexToAdd));
+ }
+
+ private static ArrayList getInputsList() {
+ var listSize = randomIntBetween(1, 10);
+ var inputs = new ArrayList(listSize);
+ for (int i = 0; i < listSize; ++i) {
+ inputs.add(new InferenceStringGroup("a_string"));
+ }
+ return inputs;
+ }
+
@Override
protected InferenceStringGroup mutateInstanceForVersion(InferenceStringGroup instance, TransportVersion version) {
return instance;
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java
index 1de31682a3524..553fc6ea99b20 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/EmbeddingActionRequestTests.java
@@ -23,7 +23,9 @@
import java.io.IOException;
import java.util.List;
+import java.util.Map;
+import static org.elasticsearch.inference.EmbeddingRequest.JINA_AI_EMBEDDING_TASK_ADDED;
import static org.elasticsearch.xpack.core.inference.action.EmbeddingAction.Request.parseRequest;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@@ -39,7 +41,10 @@ public void testParseRequest() throws IOException {
"content": {"type": "image", "format": "base64", "value": "some image input" }
}
],
- "input_type": "search"
+ "input_type": "search",
+ "task_settings": {
+ "field": "value"
+ }
}
""";
try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) {
@@ -53,7 +58,8 @@ public void testParseRequest() throws IOException {
taskType,
new EmbeddingRequest(
List.of(new InferenceStringGroup(new InferenceString(DataType.IMAGE, DataFormat.BASE64, "some image input"))),
- InputType.SEARCH
+ InputType.SEARCH,
+ Map.of("field", "value")
),
context,
timeout
@@ -73,7 +79,7 @@ public void testValidate_withNullEmbeddingRequestInputs_returnsValidationExcepti
var request = new EmbeddingAction.Request(
randomAlphanumericOfLength(8),
TaskType.EMBEDDING,
- new EmbeddingRequest(null, randomFrom(InputType.values())),
+ new EmbeddingRequest(null, randomFrom(InputType.values()), Map.of()),
new InferenceContext(randomAlphaOfLength(10)),
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
);
@@ -87,7 +93,7 @@ public void testValidate_withEmptyEmbeddingRequestInputs_returnsValidationExcept
var request = new EmbeddingAction.Request(
randomAlphanumericOfLength(8),
TaskType.EMBEDDING,
- new EmbeddingRequest(List.of(), randomFrom(InputType.values())),
+ new EmbeddingRequest(List.of(), randomFrom(InputType.values()), Map.of()),
new InferenceContext(randomAlphaOfLength(10)),
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
);
@@ -112,10 +118,11 @@ public void testValidate_withNonEmbeddingTaskType_returnsValidationException() {
}
public void testValidate_withMultipleValidationErrors_returnsAll() {
+ // Create a request with an invalid task type and null inputs
var request = new EmbeddingAction.Request(
randomAlphanumericOfLength(8),
randomValueOtherThanMany(TaskType.EMBEDDING::isAnyOrSame, () -> randomFrom(TaskType.values())),
- new EmbeddingRequest(null, randomFrom(InputType.values())),
+ new EmbeddingRequest(null, randomFrom(InputType.values()), Map.of()),
new InferenceContext(randomAlphaOfLength(10)),
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
);
@@ -128,16 +135,25 @@ public void testValidate_withMultipleValidationErrors_returnsAll() {
@Override
protected EmbeddingAction.Request mutateInstanceForVersion(EmbeddingAction.Request instance, TransportVersion version) {
+ // Use empty task settings if node is on a version before Jina AI embedding task support was added, since embedding request task
+ // settings were added with that change
+ var embeddingRequest = instance.getEmbeddingRequest();
+ if (version.supports(JINA_AI_EMBEDDING_TASK_ADDED) == false) {
+ embeddingRequest = new EmbeddingRequest(embeddingRequest.inputs(), embeddingRequest.inputType(), Map.of());
+ }
+
+ var context = instance.getContext();
if (version.supports(INFERENCE_CONTEXT) == false) {
- return new EmbeddingAction.Request(
- instance.getInferenceEntityId(),
- instance.getTaskType(),
- instance.getEmbeddingRequest(),
- InferenceContext.EMPTY_INSTANCE,
- instance.getTimeout()
- );
+ context = InferenceContext.EMPTY_INSTANCE;
}
- return instance;
+
+ return new EmbeddingAction.Request(
+ instance.getInferenceEntityId(),
+ instance.getTaskType(),
+ embeddingRequest,
+ context,
+ instance.getTimeout()
+ );
}
@Override
@@ -160,7 +176,11 @@ public static EmbeddingAction.Request createRandom() {
}
private static EmbeddingRequest randomEmbeddingRequest() {
- return new EmbeddingRequest(List.of(new InferenceStringGroup(randomAlphanumericOfLength(8))), randomFrom(InputType.values()));
+ return new EmbeddingRequest(
+ List.of(new InferenceStringGroup(randomAlphanumericOfLength(8))),
+ randomFrom(InputType.values()),
+ Map.of(randomAlphanumericOfLength(8), randomAlphanumericOfLength(8))
+ );
}
@Override
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java
index a7a6f78e26b7f..771512f0d1f13 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java
@@ -15,14 +15,13 @@
import java.util.HashMap;
import java.util.Map;
+import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.DEFAULT_SETTINGS;
import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT;
import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT;
import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder.WORDS_PER_TOKEN;
public class ChunkingSettingsBuilderTests extends ESTestCase {
- public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1);
-
public void testNullChunkingSettingsMap() {
ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(null);
assertEquals(ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS, chunkingSettings);
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
index 69976cb5d6b82..b5c93171dcb29 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
@@ -233,7 +233,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
}
public void testGetServicesWithEmbeddingTaskType() throws IOException {
- assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service").toArray()));
+ assertThat(providersFor(TaskType.EMBEDDING), containsInAnyOrder(List.of("text_embedding_test_service", "jinaai").toArray()));
}
private List