Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/140323.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 140323
summary: "[Inference API] Add support for embedding task to JinaAI service"
area: Inference
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -26,9 +27,11 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static java.util.Collections.singletonList;
import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

Expand Down Expand Up @@ -62,18 +65,25 @@
* OR
* <pre>
* "input": ["first text input", "second text input"]</pre>
* @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for
* @param inputType The {@link InputType} of the request
*
* @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for
* @param inputType The {@link InputType} of the request
* @param taskSettings The map of task settings specific to this request
*/
public record EmbeddingRequest(List<InferenceStringGroup> inputs, InputType inputType) implements Writeable, ToXContentFragment {
public record EmbeddingRequest(List<InferenceStringGroup> inputs, InputType inputType, Map<String, Object> taskSettings)
implements
Writeable,
ToXContentFragment {

private static final String INPUT_FIELD = "input";
public static final TransportVersion JINA_AI_EMBEDDING_TASK_ADDED = TransportVersion.fromName("jina_ai_embedding_task_added");

public static final String INPUT_FIELD = "input";
private static final String INPUT_TYPE_FIELD = "input_type";

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingRequest, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingRequest.class.getSimpleName(),
args -> new EmbeddingRequest((List<InferenceStringGroup>) args[0], (InputType) args[1])
args -> new EmbeddingRequest((List<InferenceStringGroup>) args[0], (InputType) args[1], (Map<String, Object>) args[2])
);

static {
Expand All @@ -89,31 +99,48 @@ public record EmbeddingRequest(List<InferenceStringGroup> inputs, InputType inpu
new ParseField(INPUT_TYPE_FIELD),
ObjectParser.ValueType.STRING
);
PARSER.declareField(
optionalConstructorArg(),
(parser, context) -> parser.mapOrdered(),
new ParseField(TASK_SETTINGS),
ObjectParser.ValueType.OBJECT
);
}

public static EmbeddingRequest of(List<InferenceStringGroup> contents) {
return new EmbeddingRequest(contents, null);
return new EmbeddingRequest(contents, null, null);
}

public EmbeddingRequest(List<InferenceStringGroup> inputs, @Nullable InputType inputType) {
public EmbeddingRequest(List<InferenceStringGroup> inputs, @Nullable InputType inputType, @Nullable Map<String, Object> taskSettings) {
this.inputs = inputs;
this.inputType = Objects.requireNonNullElse(inputType, InputType.UNSPECIFIED);
this.taskSettings = Objects.requireNonNullElse(taskSettings, Map.of());
}

public EmbeddingRequest(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(InferenceStringGroup::new), in.readEnum(InputType.class));
this(
in.readCollectionAsImmutableList(InferenceStringGroup::new),
in.readEnum(InputType.class),
in.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED) ? in.readGenericMap() : Map.of()
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(inputs);
out.writeEnum(inputType);
if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TASK_ADDED)) {
out.writeGenericMap(taskSettings);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(INPUT_FIELD, inputs);
builder.field(INPUT_TYPE_FIELD, inputType);
if (taskSettings.isEmpty() == false) {
builder.field(TASK_SETTINGS, taskSettings);
}
return builder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
Expand All @@ -37,20 +38,31 @@
* ]
* }
* </pre>
* @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector
*/
public record InferenceStringGroup(List<InferenceString> inferenceStrings) implements Writeable, ToXContentObject {
private static final String CONTENT_FIELD = "content";
public final class InferenceStringGroup implements Writeable, ToXContentObject {
public static final String CONTENT_FIELD = "content";

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<InferenceStringGroup, Void> PARSER = new ConstructingObjectParser<>(
InferenceStringGroup.class.getSimpleName(),
args -> new InferenceStringGroup((List<InferenceString>) args[0])
);

static {
PARSER.declareObjectArray(constructorArg(), InferenceString.PARSER::apply, new ParseField(CONTENT_FIELD));
}

private final List<InferenceString> inferenceStrings;
private final boolean containsNonTextEntry;

/**
* @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector
*/
public InferenceStringGroup(List<InferenceString> inferenceStrings) {
this.inferenceStrings = inferenceStrings;
containsNonTextEntry = inferenceStrings.stream().anyMatch(s -> s.isText() == false);
}

public InferenceStringGroup(StreamInput in) throws IOException {
this(in.readCollectionAsImmutableList(InferenceString::new));
}
Expand All @@ -64,6 +76,18 @@ public InferenceStringGroup(String input) {
this(singletonList(new InferenceString(DataType.TEXT, input)));
}

public List<InferenceString> inferenceStrings() {
return inferenceStrings;
}

public boolean containsNonTextEntry() {
return containsNonTextEntry;
}

public boolean containsMultipleInferenceStrings() {
return inferenceStrings.size() > 1;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(inferenceStrings);
Expand All @@ -81,7 +105,7 @@ public static InferenceStringGroup parse(XContentParser parser) throws IOExcepti
var token = parser.currentToken();
if (token == XContentParser.Token.VALUE_STRING) {
// Create content object from String
return new InferenceStringGroup(singletonList(new InferenceString(DataType.TEXT, parser.text())));
return new InferenceStringGroup(parser.text());
} else if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) {
// Create content object from InferenceString(s)
return InferenceStringGroup.PARSER.apply(parser, null);
Expand Down Expand Up @@ -134,4 +158,51 @@ public static List<InferenceString> toInferenceStringList(List<InferenceStringGr
public static List<String> toStringList(List<InferenceStringGroup> inferenceStringGroups) {
return InferenceString.toStringList(toInferenceStringList(inferenceStringGroups));
}

/**
* Method used to determine if a list of {@link InferenceStringGroup} contains any {@link InferenceString} that represent a non-text
* value
*
* @param inferenceStringGroups the list of {@link InferenceStringGroup} to check
* @return true if the input list contains any non-text values, false otherwise
*/
public static boolean containsNonTextEntry(List<InferenceStringGroup> inferenceStringGroups) {
return inferenceStringGroups.stream().anyMatch(InferenceStringGroup::containsNonTextEntry);
}

/**
* Method used to determine if a list of {@link InferenceStringGroup} contains any with more than one {@link InferenceString} in them
*
* @param inferenceStringGroups the list of {@link InferenceStringGroup} to check
* @return the index of the first {@link InferenceStringGroup} found to contain more than one {@link InferenceString}, or null if no
* elements in the list contain more than one {@link InferenceString}
*/
public static Integer indexContainingMultipleInferenceStrings(List<InferenceStringGroup> inferenceStringGroups) {
for (int i = 0; i < inferenceStringGroups.size(); i++) {
InferenceStringGroup inferenceStringGroup = inferenceStringGroups.get(i);
if (inferenceStringGroup.containsMultipleInferenceStrings()) {
return i;
}
}
return null;
}

@Override
public boolean equals(Object obj) {
if (obj == this) return true;
if (obj == null || obj.getClass() != this.getClass()) return false;
var that = (InferenceStringGroup) obj;
return Objects.equals(this.inferenceStrings, that.inferenceStrings);
}

@Override
public int hashCode() {
return Objects.hash(inferenceStrings);
}

@Override
public String toString() {
return "InferenceStringGroup[" + "inferenceStrings=" + inferenceStrings + ']';
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ default DenseVectorFieldMapper.ElementType elementType() {
return null;
}

default boolean isMultimodal() {
return false;
}

/**
* The model to use in the inference endpoint (e.g. text-embedding-ada-002). Sometimes the model is not defined in the service
* settings. This can happen for external providers (e.g. hugging face, azure ai studio) where the provider requires that the model
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9259000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.4.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
esql_vsr_converters_used,9258000
jina_ai_embedding_task_added,9259000
Loading