Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/138198.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138198
summary: Add Embedding inference task type
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,17 @@

import java.util.List;

import static org.elasticsearch.inference.InferenceString.DataType.TEXT;

public record ChunkInferenceInput(InferenceString input, @Nullable ChunkingSettings chunkingSettings) {
public record ChunkInferenceInput(InferenceStringGroup input, @Nullable ChunkingSettings chunkingSettings) {

public ChunkInferenceInput(String input) {
this(new InferenceString(input, TEXT), null);
this(new InferenceStringGroup(input), null);
}

public static List<InferenceString> inputs(List<ChunkInferenceInput> chunkInferenceInputs) {
public static List<InferenceStringGroup> inputs(List<ChunkInferenceInput> chunkInferenceInputs) {
return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList();
}

public String inputText() {
assert input.isText();
return input.value();
return input.textValue();
}
}
132 changes: 132 additions & 0 deletions server/src/main/java/org/elasticsearch/inference/EmbeddingRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceString.DataFormat;
import org.elasticsearch.inference.InferenceString.DataType;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* This class handles the parsing of inputs used by the {@link TaskType#EMBEDDING} task type. The input for this task is specified using
* a list of "content" objects, each of which specifies the {@link DataType}, {@link DataFormat} and the String value of the input. The
* {@code format} field is optional, and if not specified will use the default {@link DataFormat} for the given {@link DataType}:
* <pre>
* "input": [
* {
* "content": {"type": "image", "format": "base64", "value": "image data"}
* },
* {
* "content": [
* {"type": "text", "value": "text input"},
* {"type": "image", "value": "image data"}
* ]
* }
* ]</pre>
* It is also possible to specify a single content object rather than a
* list:
* <pre>
* "input": {
* "content": {"type": "text", "format": "text", "value": "text input"}
* }</pre>
* To preserve input compatibility with the existing {@link TaskType#TEXT_EMBEDDING} task, the input can also be specified as a single
* String or a list of Strings, each of which will be parsed into a content object with {@link DataType} equal to
* {@link DataType#TEXT} and {@link DataFormat} equal to {@link DataFormat#TEXT}:
* <pre>
* "input": "single text input"</pre>
* OR
* <pre>
* "input": ["first text input", "second text input"]</pre>
* @param inputs The list of {@link InferenceStringGroup} inputs to generate embeddings for
* @param inputType The {@link InputType} of the request
*/
public record EmbeddingRequest(List<InferenceStringGroup> inputs, InputType inputType) implements Writeable, ToXContentFragment {

    private static final String INPUT_FIELD = "input";
    private static final String INPUT_TYPE_FIELD = "input_type";

    /**
     * Parser for the request body: a required {@code input} field and an optional {@code input_type} field.
     */
    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<EmbeddingRequest, Void> PARSER = new ConstructingObjectParser<>(
        EmbeddingRequest.class.getSimpleName(),
        args -> {
            // args follow the declaration order in the static initialiser below
            var parsedInputs = (List<InferenceStringGroup>) args[0];
            var parsedInputType = (InputType) args[1];
            return new EmbeddingRequest(parsedInputs, parsedInputType);
        }
    );

    static {
        // "input" accepts a String, an object, or an array of either; parseInput tells them apart by token
        PARSER.declareField(
            constructorArg(),
            (p, ctx) -> parseInput(p),
            new ParseField(INPUT_FIELD),
            ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING
        );
        // "input_type" may be omitted, in which case the constructor receives null
        PARSER.declareField(
            optionalConstructorArg(),
            (p, ctx) -> InputType.fromString(p.text()),
            new ParseField(INPUT_TYPE_FIELD),
            ObjectParser.ValueType.STRING
        );
    }

    /**
     * Creates a request for the given inputs without an explicit input type.
     *
     * @param contents the inputs to generate embeddings for
     * @return a request whose input type is {@link InputType#UNSPECIFIED}
     */
    public static EmbeddingRequest of(List<InferenceStringGroup> contents) {
        // A null input type is replaced with UNSPECIFIED by the canonical constructor
        return new EmbeddingRequest(contents, null);
    }

    /**
     * Canonical constructor.
     *
     * @param inputs    the inputs to generate embeddings for
     * @param inputType the input type, or {@code null} to use {@link InputType#UNSPECIFIED}
     */
    public EmbeddingRequest(List<InferenceStringGroup> inputs, @Nullable InputType inputType) {
        this.inputs = inputs;
        this.inputType = Objects.requireNonNullElse(inputType, InputType.UNSPECIFIED);
    }

    /**
     * Reads a request from the stream: first the list of inputs, then the input type enum.
     *
     * @param in the stream to read from
     */
    public EmbeddingRequest(StreamInput in) throws IOException {
        this(in.readCollectionAsImmutableList(InferenceStringGroup::new), in.readEnum(InputType.class));
    }

    /**
     * Writes the inputs followed by the input type, mirroring the read order of the {@link StreamInput} constructor.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(inputs);
        out.writeEnum(inputType);
    }

    /**
     * Emits the {@code input} and {@code input_type} fields into the builder without opening a new object.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.field(INPUT_FIELD, inputs).field(INPUT_TYPE_FIELD, inputType);
    }

    /**
     * Parses the value of the {@code input} field, dispatching on the parser's current token.
     *
     * @param parser the parser positioned on the first token of the field value
     * @return the parsed inputs; a single String or object yields a one-element list
     * @throws XContentParseException if the current token is not a String, an object start or an array start
     */
    private static List<InferenceStringGroup> parseInput(XContentParser parser) throws IOException {
        final XContentParser.Token current = parser.currentToken();

        if (current == XContentParser.Token.START_ARRAY) {
            // Array of String or content objects
            return XContentParserUtils.parseList(parser, InferenceStringGroup::parse);
        }

        final boolean isSingleInput = current == XContentParser.Token.VALUE_STRING || current == XContentParser.Token.START_OBJECT;
        if (isSingleInput == false) {
            throw new XContentParseException("Unsupported token [" + current + "]");
        }

        // Single input of String or content object
        return singletonList(InferenceStringGroup.parse(parser));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ void unifiedCompletionInfer(
ActionListener<InferenceServiceResults> listener
);

/**
 * Perform multimodal embedding inference on the model using the embedding schema.
 * <p>
 * The method itself returns nothing; the outcome is presumably reported through {@code listener}
 * (NOTE(review): confirm against the implementations).
 *
 * @param model    The model
 * @param request  Parameters for the request, i.e. the {@link EmbeddingRequest} holding the inputs and their {@link InputType}
 * @param timeout  The timeout for the request
 * @param listener Inference result listener
 */
void embeddingInfer(Model model, EmbeddingRequest request, TimeValue timeout, ActionListener<InferenceServiceResults> listener);

/**
* Chunk long text.
*
Expand Down
173 changes: 158 additions & 15 deletions server/src/main/java/org/elasticsearch/inference/InferenceString.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,198 @@

package org.elasticsearch.inference;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* This class represents a String which may be raw text, or the String representation of some other data such as an image in base64
*/
public record InferenceString(String value, DataType dataType) {
public record InferenceString(DataType dataType, DataFormat dataFormat, String value) implements Writeable, ToXContentObject {
private static final String TYPE_FIELD = "type";
private static final String FORMAT_FIELD = "format";
private static final String VALUE_FIELD = "value";

/**
 * The kind of data held by an {@link InferenceString}. Each type names the {@link DataFormat} used as its default.
 */
public enum DataType {
    TEXT(DataFormat.TEXT),
    IMAGE(DataFormat.BASE64);

    private final DataFormat defaultFormat;

    DataType(DataFormat defaultFormat) {
        this.defaultFormat = defaultFormat;
    }

    /**
     * @return the constant name in lower case
     */
    @Override
    public String toString() {
        return name().toLowerCase(Locale.ROOT);
    }

    /**
     * Resolves a data type from its name, ignoring case and surrounding whitespace.
     *
     * @param name the name to resolve
     * @return the matching data type
     * @throws IllegalArgumentException if no constant matches {@code name}
     */
    public static DataType fromString(String name) {
        final String constantName = name.trim().toUpperCase(Locale.ROOT);
        try {
            return valueOf(constantName);
        } catch (IllegalArgumentException ignored) {
            // Replace the generic enum lookup failure with a message that lists the accepted values
            final String validTypes = Arrays.toString(DataType.values());
            throw new IllegalArgumentException(Strings.format("Unrecognized type [%s], must be one of %s", name, validTypes));
        }
    }
}

/**
* Describes the format of data represented by an {@link InferenceString}
*/
public enum DataFormat {
TEXT,
IMAGE_BASE64
BASE64;

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}

public static DataFormat fromString(String name) {
try {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
} catch (IllegalArgumentException ex) {
throw new IllegalArgumentException(
Strings.format("Unrecognized format [%s], must be one of %s", name, Arrays.toString(DataFormat.values()))
);
}
}
}

private static final EnumSet<DataType> IMAGE_TYPES = EnumSet.of(DataType.IMAGE_BASE64);
/**
 * Parses an object made of a required {@code type}, an optional {@code format} and a required {@code value}.
 */
static final ConstructingObjectParser<InferenceString, Void> PARSER = new ConstructingObjectParser<>(
    InferenceString.class.getSimpleName(),
    args -> {
        // args follow the declaration order in the static initialiser below
        final var type = (InferenceString.DataType) args[0];
        final var format = (InferenceString.DataFormat) args[1];
        final var content = (String) args[2];
        return new InferenceString(type, format, content);
    }
);

static {
    // type: required, converted with DataType.fromString
    PARSER.declareString(constructorArg(), DataType::fromString, new ParseField(TYPE_FIELD));
    // format: optional, converted with DataFormat.fromString
    PARSER.declareString(optionalConstructorArg(), DataFormat::fromString, new ParseField(FORMAT_FIELD));
    // value: required, taken as-is
    PARSER.declareString(constructorArg(), new ParseField(VALUE_FIELD));
}

/**
* Constructs an {@link InferenceString} with the given value and {@link DataType}
* @param value the String value
* Constructs an {@link InferenceString} with the given value and {@link DataType}, using the
* default {@link DataFormat} for the data type
*
* @param dataType the type of data that the String represents
* @param value the String value
*/
public InferenceString(String value, DataType dataType) {
this.value = Objects.requireNonNull(value);
public InferenceString(DataType dataType, String value) {
this(dataType, null, value);
}

/**
 * Constructs an {@link InferenceString} with the given value, {@link DataType} and {@link DataFormat}
 *
 * @param dataType the type of data that the String represents, must not be {@code null}
 * @param dataFormat the format of the data. If {@code null}, the default data format for the given type is used
 * @param value the String value, must not be {@code null}
 * @throws NullPointerException if {@code dataType} or {@code value} is {@code null}
 * @throws IllegalArgumentException if {@code dataFormat} is not one of the formats supported for {@code dataType}
 */
public InferenceString(DataType dataType, @Nullable DataFormat dataFormat, String value) {
this.dataType = Objects.requireNonNull(dataType);
// Fall back to the type's default format when the caller did not supply one
this.dataFormat = Objects.requireNonNullElse(dataFormat, this.dataType.defaultFormat);
// Reads the two fields assigned above, so it has to run after them; the value plays no part in the check
validateTypeAndFormat();
this.value = Objects.requireNonNull(value);
}

/**
 * Rejects a {@link DataFormat} that is not among the formats supported by this instance's {@link DataType}.
 *
 * @throws IllegalArgumentException if the data format is not supported for the data type
 */
private void validateTypeAndFormat() {
    final EnumSet<DataFormat> supportedFormats = supportedFormatsForType(dataType);
    if (supportedFormats.contains(dataFormat)) {
        return;
    }

    final String message = Strings.format(
        "Data type [%s] does not support data format [%s], supported formats are %s",
        dataType,
        dataFormat,
        supportedFormats
    );
    throw new IllegalArgumentException(message);
}

/**
 * Reads an {@link InferenceString} from the stream: the {@link DataType} enum, then the {@link DataFormat} enum, then the
 * String value. The values are handed to the three-argument constructor, so its null and type/format checks apply here too.
 *
 * @param in the stream to read from
 */
public InferenceString(StreamInput in) throws IOException {
this(in.readEnum(DataType.class), in.readEnum(DataFormat.class), in.readString());
}

public boolean isImage() {
return IMAGE_TYPES.contains(dataType);
return DataType.IMAGE.equals(dataType);
}

/**
 * @return {@code true} if the {@link DataType} of this instance is {@link DataType#TEXT}
 */
public boolean isText() {
    return dataType == DataType.TEXT;
}

/**
 * Returns the data formats that may be combined with the given data type.
 *
 * @param type the data type to look up
 * @return a newly created set holding the formats supported for {@code type}
 */
public static EnumSet<DataFormat> supportedFormatsForType(DataType type) {
    // Each type currently maps to exactly one format
    final DataFormat soleFormat = switch (type) {
        case TEXT -> DataFormat.TEXT;
        case IMAGE -> DataFormat.BASE64;
    };
    return EnumSet.of(soleFormat);
}

/**
* Converts a list of {@link InferenceString} to a list of {@link String}.
* This method should only be called in code paths that do not deal with multimodal inputs; where all inputs are guaranteed to be
* raw text, since it discards the {@link org.elasticsearch.inference.InferenceString.DataType} associated with each input.
*
* <p>
* <b>
* This method should only be called in code paths that do not deal with multimodal inputs, i.e. code paths where all inputs are
* guaranteed to be raw text, since it discards the {@link org.elasticsearch.inference.InferenceString.DataType} associated with
* each input.
*</b>
* @param inferenceStrings The list of {@link InferenceString} to convert to a list of {@link String}
* @return a list of String inference inputs that do not contain any non-text inputs
*/
public static List<String> toStringList(List<InferenceString> inferenceStrings) {
return inferenceStrings.stream().map(i -> {
assert i.isText() : "Non-text input passed to InferenceString.toStringList";
return i.value();
}).toList();
return inferenceStrings.stream().map(InferenceString::textValue).toList();
}

/**
 * Converts a single {@link InferenceString} to a {@link String} by returning its {@link #value()}.
 * <p>
 * <b>
 * This method should only be called in code paths that do not deal with multimodal inputs, i.e. code paths where all inputs are
 * guaranteed to be raw text, since it discards the {@link org.elasticsearch.inference.InferenceString.DataType} associated with
 * each input.
 * </b>
 * <p>
 * The text-only requirement is checked with an {@code assert} only, so when assertions are disabled the value of a non-text
 * input is returned without complaint.
 *
 * @param inferenceString The {@link InferenceString} to convert to a {@link String}
 * @return a String inference input
 */
public static String textValue(InferenceString inferenceString) {
assert inferenceString.isText() : "Non-text input returned from InferenceString.textValue";
return inferenceString.value();
}

/**
 * Writes the data type enum, the data format enum and the String value to the stream, in that order.
 *
 * @param out the stream to write to
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
// The order here mirrors the reads in the StreamInput constructor
out.writeEnum(dataType);
out.writeEnum(dataFormat);
out.writeString(value);
}

/**
 * Renders this instance as an object with {@code type}, {@code format} and {@code value} fields.
 *
 * @param builder the builder to write to
 * @param params  unused
 * @return the builder that was written to
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    return builder.startObject()
        .field(TYPE_FIELD, dataType)
        .field(FORMAT_FIELD, dataFormat)
        .field(VALUE_FIELD, value)
        .endObject();
}
}
Loading