Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

public class ChunkInferenceImageInput extends ChunkInferenceInput {
public ChunkInferenceImageInput(String input) {
super(input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,28 @@

import java.util.List;

public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) {
public abstract class ChunkInferenceInput {
final String input;
final ChunkingSettings chunkingSettings;

public ChunkInferenceInput(String input) {
ChunkInferenceInput(String input) {
this(input, null);
}

ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) {
this.input = input;
this.chunkingSettings = chunkingSettings;
}

public String getInput() {
return input;
}

public ChunkingSettings getChunkingSettings() {
return chunkingSettings;
}

public static List<String> inputs(List<ChunkInferenceInput> chunkInferenceInputs) {
return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList();
return chunkInferenceInputs.stream().map(ChunkInferenceInput::getInput).toList();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

public class ChunkInferenceTextInput extends ChunkInferenceInput {
public ChunkInferenceTextInput(String input, ChunkingSettings chunkingSettings) {
super(input, chunkingSettings);
}

public ChunkInferenceTextInput(String input) {
super(input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ default boolean hideFromConfigurationApi() {
* Passing in null is specifically for query-time inference, when the timeout is managed by the
* xpack.inference.query_timeout cluster setting.
* @param listener Inference result listener
* @param imageUrls Inference input of image URLs
*/
void infer(
Model model,
Expand All @@ -120,7 +121,8 @@ void infer(
Map<String, Object> taskSettings,
InputType inputType,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
ActionListener<InferenceServiceResults> listener,
@Nullable List<String> imageUrls
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public String toString() {

private static void validate(TaskType taskType, Integer dimensions, SimilarityMeasure similarity, ElementType elementType) {
switch (taskType) {
case TEXT_EMBEDDING:
case TEXT_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING:
validateFieldPresent(DIMENSIONS_FIELD, dimensions, taskType);
validateFieldPresent(SIMILARITY_FIELD, similarity, taskType);
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public boolean isAnyOrSame(TaskType other) {
return true;
}
},
CHAT_COMPLETION;
CHAT_COMPLETION,
IMAGE_EMBEDDING,
MULTIMODAL_EMBEDDING;

public static final String NAME = "task_type";

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9177000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
roles_security_stats,9176000
ml_multimodal_embeddings,9177000
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static MinimalServiceSettings randomInstance() {
SimilarityMeasure similarity = null;
DenseVectorFieldMapper.ElementType elementType = null;

if (taskType == TaskType.TEXT_EMBEDDING) {
if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.IMAGE_EMBEDDING || taskType == TaskType.MULTIMODAL_EMBEDDING) {
dimensions = randomIntBetween(2, 1024);
similarity = randomFrom(SimilarityMeasure.values());
elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -45,6 +46,7 @@
import static org.elasticsearch.core.Strings.format;

public class InferenceAction extends ActionType<InferenceAction.Response> {
private static final TransportVersion ML_MULTIMODAL_EMBEDDINGS = TransportVersion.fromName("ml_multimodal_embeddings");

public static final InferenceAction INSTANCE = new InferenceAction();
public static final String NAME = "cluster:internal/xpack/inference";
Expand All @@ -63,6 +65,7 @@ public static class Request extends BaseInferenceActionRequest {
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
public static final ParseField TOP_N = new ParseField("top_n");
public static final ParseField TIMEOUT = new ParseField("timeout");
public static final ParseField IMAGE_URL = new ParseField("image_url");

public static Builder builder(String inferenceEntityId, TaskType taskType) {
return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType);
Expand All @@ -77,6 +80,7 @@ public static Builder builder(String inferenceEntityId, TaskType taskType) {
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
PARSER.declareStringArray(Builder::setImageUrl, IMAGE_URL);
}

private static final EnumSet<InputType> validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH);
Expand Down Expand Up @@ -104,6 +108,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
private final InputType inputType;
private final TimeValue inferenceTimeout;
private final boolean stream;
private final List<String> imageUrls;

public Request(
TaskType taskType,
Expand All @@ -128,7 +133,8 @@ public Request(
inputType,
inferenceTimeout,
stream,
InferenceContext.EMPTY_INSTANCE
InferenceContext.EMPTY_INSTANCE,
null
);
}

Expand All @@ -143,7 +149,8 @@ public Request(
InputType inputType,
TimeValue inferenceTimeout,
boolean stream,
InferenceContext context
InferenceContext context,
@Nullable List<String> imageUrls
) {
super(context);
this.taskType = taskType;
Expand All @@ -156,6 +163,7 @@ public Request(
this.inputType = inputType;
this.inferenceTimeout = inferenceTimeout;
this.stream = stream;
this.imageUrls = imageUrls;
}

public Request(StreamInput in) throws IOException {
Expand Down Expand Up @@ -191,6 +199,12 @@ public Request(StreamInput in) throws IOException {
this.topN = null;
}

if (in.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) {
imageUrls = in.readOptionalStringCollectionAsList();
} else {
imageUrls = null;
}

// streaming is not supported yet for transport traffic
this.stream = false;
}
Expand Down Expand Up @@ -235,18 +249,48 @@ public boolean isStreaming() {
return stream;
}

public List<String> getImageUrls() {
return imageUrls;
}

@Override
public ActionRequestValidationException validate() {
if (input == null) {
var e = new ActionRequestValidationException();
e.addValidationError("Field [input] cannot be null");
return e;
}
if (taskType == TaskType.IMAGE_EMBEDDING) {
if (imageUrls == null) {
var e = new ActionRequestValidationException();
e.addValidationError("Field [image_url] cannot be null");
return e;
}

if (input.isEmpty()) {
var e = new ActionRequestValidationException();
e.addValidationError("Field [input] cannot be an empty array");
return e;
if (imageUrls.isEmpty()) {
var e = new ActionRequestValidationException();
e.addValidationError("Field [imageUrl] cannot be an empty array");
return e;
}
} else if (taskType == TaskType.MULTIMODAL_EMBEDDING) {
if (input == null && imageUrls == null) {
var e = new ActionRequestValidationException();
e.addValidationError("Fields [input] and [image_url] cannot both be null");
return e;
}

if (input != null && input.isEmpty() && imageUrls != null && imageUrls.isEmpty()) {
var e = new ActionRequestValidationException();
e.addValidationError("Fields [input] cannot both be empty arrays");
return e;
}
} else {
if (input == null) {
var e = new ActionRequestValidationException();
e.addValidationError("Field [input] cannot be null");
return e;
}

if (input.isEmpty()) {
var e = new ActionRequestValidationException();
e.addValidationError("Field [input] cannot be an empty array");
return e;
}
}

if (taskType.equals(TaskType.RERANK)) {
Expand All @@ -273,15 +317,15 @@ public ActionRequestValidationException validate() {
}
}

if (taskType.equals(TaskType.TEXT_EMBEDDING) || taskType.equals(TaskType.SPARSE_EMBEDDING)) {
if (isNonSparseEmbedding() || taskType.equals(TaskType.SPARSE_EMBEDDING)) {
if (query != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [query] cannot be specified for task type [%s]", taskType));
return e;
}
}

if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
if (isNonSparseEmbedding() == false
&& taskType.equals(TaskType.ANY) == false
&& (inputType != null && InputType.isInternalTypeOrUnspecified(inputType) == false)) {
var e = new ActionRequestValidationException();
Expand All @@ -292,6 +336,12 @@ public ActionRequestValidationException validate() {
return null;
}

private boolean isNonSparseEmbedding() {
return taskType.equals(TaskType.TEXT_EMBEDDING)
|| taskType.equals(TaskType.IMAGE_EMBEDDING)
|| taskType.equals(TaskType.MULTIMODAL_EMBEDDING);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand All @@ -318,6 +368,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalBoolean(returnDocuments);
out.writeOptionalInt(topN);
}

if (out.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) {
out.writeOptionalStringCollection(imageUrls);
}
}

// default for easier testing
Expand Down Expand Up @@ -348,7 +402,8 @@ public boolean equals(Object o) {
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& inputType == request.inputType
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
&& Objects.equals(inferenceTimeout, request.inferenceTimeout)
&& Objects.equals(imageUrls, request.imageUrls);
}

@Override
Expand All @@ -364,7 +419,8 @@ public int hashCode() {
taskSettings,
inputType,
inferenceTimeout,
stream
stream,
imageUrls
);
}

Expand All @@ -381,6 +437,7 @@ public static class Builder {
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;
private List<String> imageUrl;

private Builder() {}

Expand Down Expand Up @@ -448,6 +505,11 @@ public Builder setContext(InferenceContext context) {
return this;
}

public Builder setImageUrl(List<String> imageUrl) {
this.imageUrl = imageUrl;
return this;
}

public Request build() {
return new Request(
taskType,
Expand All @@ -460,7 +522,8 @@ public Request build() {
inputType,
timeout,
stream,
context
context,
imageUrl
);
}
}
Expand All @@ -486,6 +549,8 @@ public String toString() {
+ this.getInferenceTimeout()
+ ", context="
+ this.getContext()
+ ", imageURL="
+ this.getImageUrls()
+ ")";
}
}
Expand Down
Loading