diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceImageInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceImageInput.java new file mode 100644 index 0000000000000..ad76da81cbdd5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceImageInput.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +public class ChunkInferenceImageInput extends ChunkInferenceInput { + public ChunkInferenceImageInput(String input) { + super(input); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java index 8e25e0e55f08c..0885f52511951 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -13,13 +13,28 @@ import java.util.List; -public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) { +public abstract class ChunkInferenceInput { + final String input; + final ChunkingSettings chunkingSettings; - public ChunkInferenceInput(String input) { + ChunkInferenceInput(String input) { this(input, null); } + ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) { + this.input = input; + this.chunkingSettings = chunkingSettings; + } + + public String getInput() { + return input; + } + + public ChunkingSettings getChunkingSettings() { + return 
chunkingSettings; + } + public static List inputs(List chunkInferenceInputs) { - return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); + return chunkInferenceInputs.stream().map(ChunkInferenceInput::getInput).toList(); } } diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceTextInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceTextInput.java new file mode 100644 index 0000000000000..23fa5d917b459 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceTextInput.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +public class ChunkInferenceTextInput extends ChunkInferenceInput { + public ChunkInferenceTextInput(String input, ChunkingSettings chunkingSettings) { + super(input, chunkingSettings); + } + + public ChunkInferenceTextInput(String input) { + super(input); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 916c65777cd30..e2a97e2008f6e 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -109,6 +109,7 @@ default boolean hideFromConfigurationApi() { * Passing in null is specifically for query-time inference, when the timeout is managed by the * xpack.inference.query_timeout cluster setting. 
* @param listener Inference result listener + * @param imageUrls Inference input of image URLs */ void infer( Model model, @@ -120,7 +121,8 @@ void infer( Map taskSettings, InputType inputType, @Nullable TimeValue timeout, - ActionListener listener + ActionListener listener, + @Nullable List imageUrls ); /** diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java index b9d2696b347c7..c8badce8016c9 100644 --- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java @@ -223,7 +223,7 @@ public String toString() { private static void validate(TaskType taskType, Integer dimensions, SimilarityMeasure similarity, ElementType elementType) { switch (taskType) { - case TEXT_EMBEDDING: + case TEXT_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING: validateFieldPresent(DIMENSIONS_FIELD, dimensions, taskType); validateFieldPresent(SIMILARITY_FIELD, similarity, taskType); validateFieldPresent(ELEMENT_TYPE_FIELD, elementType, taskType); diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index 73a0e3cc8a774..0aeae78c80ed6 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -31,7 +31,9 @@ public boolean isAnyOrSame(TaskType other) { return true; } }, - CHAT_COMPLETION; + CHAT_COMPLETION, + IMAGE_EMBEDDING, + MULTIMODAL_EMBEDDING; public static final String NAME = "task_type"; diff --git a/server/src/main/resources/transport/definitions/referable/ml_multimodal_embeddings.csv b/server/src/main/resources/transport/definitions/referable/ml_multimodal_embeddings.csv new file mode 100644 index 0000000000000..4b012195cac35 --- /dev/null +++ 
b/server/src/main/resources/transport/definitions/referable/ml_multimodal_embeddings.csv @@ -0,0 +1 @@ +9177000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 78180d915cd67..a048b5591698d 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -roles_security_stats,9176000 +ml_multimodal_embeddings,9177000 diff --git a/server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java b/server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java index 247fdc7a89593..8fdc15fd90af2 100644 --- a/server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java +++ b/server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java @@ -22,7 +22,7 @@ public static MinimalServiceSettings randomInstance() { SimilarityMeasure similarity = null; DenseVectorFieldMapper.ElementType elementType = null; - if (taskType == TaskType.TEXT_EMBEDDING) { + if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.IMAGE_EMBEDDING || taskType == TaskType.MULTIMODAL_EMBEDDING) { dimensions = randomIntBetween(2, 1024); similarity = randomFrom(SimilarityMeasure.values()); elementType = randomFrom(DenseVectorFieldMapper.ElementType.values()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index c23996a3ce87a..713087cb89fe1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; 
import org.elasticsearch.common.xcontent.ChunkedToXContentObject; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; @@ -45,6 +46,7 @@ import static org.elasticsearch.core.Strings.format; public class InferenceAction extends ActionType { + private static final TransportVersion ML_MULTIMODAL_EMBEDDINGS = TransportVersion.fromName("ml_multimodal_embeddings"); public static final InferenceAction INSTANCE = new InferenceAction(); public static final String NAME = "cluster:internal/xpack/inference"; @@ -63,6 +65,7 @@ public static class Request extends BaseInferenceActionRequest { public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents"); public static final ParseField TOP_N = new ParseField("top_n"); public static final ParseField TIMEOUT = new ParseField("timeout"); + public static final ParseField IMAGE_URL = new ParseField("image_url"); public static Builder builder(String inferenceEntityId, TaskType taskType) { return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType); @@ -77,6 +80,7 @@ public static Builder builder(String inferenceEntityId, TaskType taskType) { PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS); PARSER.declareInt(Request.Builder::setTopN, TOP_N); PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT); + PARSER.declareStringArray(Builder::setImageUrl, IMAGE_URL); } private static final EnumSet validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH); @@ -104,6 +108,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final InputType inputType; private final TimeValue inferenceTimeout; private final boolean stream; + private final List imageUrls; public Request( TaskType taskType, @@ -128,7 +133,8 @@ public Request( inputType, inferenceTimeout, stream, - 
InferenceContext.EMPTY_INSTANCE + InferenceContext.EMPTY_INSTANCE, + null ); } @@ -143,7 +149,8 @@ public Request( InputType inputType, TimeValue inferenceTimeout, boolean stream, - InferenceContext context + InferenceContext context, + @Nullable List imageUrls ) { super(context); this.taskType = taskType; @@ -156,6 +163,7 @@ public Request( this.inputType = inputType; this.inferenceTimeout = inferenceTimeout; this.stream = stream; + this.imageUrls = imageUrls; } public Request(StreamInput in) throws IOException { @@ -191,6 +199,12 @@ public Request(StreamInput in) throws IOException { this.topN = null; } + if (in.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + imageUrls = in.readOptionalStringCollectionAsList(); + } else { + imageUrls = null; + } + // streaming is not supported yet for transport traffic this.stream = false; } @@ -235,18 +249,48 @@ public boolean isStreaming() { return stream; } + public List getImageUrls() { + return imageUrls; + } + @Override public ActionRequestValidationException validate() { - if (input == null) { - var e = new ActionRequestValidationException(); - e.addValidationError("Field [input] cannot be null"); - return e; - } + if (taskType == TaskType.IMAGE_EMBEDDING) { + if (imageUrls == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [image_url] cannot be null"); + return e; + } - if (input.isEmpty()) { - var e = new ActionRequestValidationException(); - e.addValidationError("Field [input] cannot be an empty array"); - return e; + if (imageUrls.isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [image_url] cannot be an empty array"); + return e; + } + } else if (taskType == TaskType.MULTIMODAL_EMBEDDING) { + if (input == null && imageUrls == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Fields [input] and [image_url] cannot both be null"); + return e; + } + + if (input != null && input.isEmpty() && 
imageUrls != null && imageUrls.isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Fields [input] and [image_url] cannot both be empty arrays"); + return e; + } + } else { + if (input == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [input] cannot be null"); + return e; + } + + if (input.isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [input] cannot be an empty array"); + return e; + } } if (taskType.equals(TaskType.RERANK)) { @@ -273,7 +317,7 @@ public ActionRequestValidationException validate() { } } - if (taskType.equals(TaskType.TEXT_EMBEDDING) || taskType.equals(TaskType.SPARSE_EMBEDDING)) { + if (isNonSparseEmbedding() || taskType.equals(TaskType.SPARSE_EMBEDDING)) { if (query != null) { var e = new ActionRequestValidationException(); e.addValidationError(format("Field [query] cannot be specified for task type [%s]", taskType)); @@ -281,7 +325,7 @@ } - if (taskType.equals(TaskType.TEXT_EMBEDDING) == false + if (isNonSparseEmbedding() == false && taskType.equals(TaskType.ANY) == false && (inputType != null && InputType.isInternalTypeOrUnspecified(inputType) == false)) { var e = new ActionRequestValidationException(); @@ -292,6 +336,12 @@ public ActionRequestValidationException validate() { return null; } + private boolean isNonSparseEmbedding() { + return taskType.equals(TaskType.TEXT_EMBEDDING) + || taskType.equals(TaskType.IMAGE_EMBEDDING) + || taskType.equals(TaskType.MULTIMODAL_EMBEDDING); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -318,6 +368,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(returnDocuments); out.writeOptionalInt(topN); } + + if (out.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + out.writeOptionalStringCollection(imageUrls); + } } // default for easier testing @@ -348,7 
+402,8 @@ public boolean equals(Object o) { && Objects.equals(input, request.input) && Objects.equals(taskSettings, request.taskSettings) && inputType == request.inputType - && Objects.equals(inferenceTimeout, request.inferenceTimeout); + && Objects.equals(inferenceTimeout, request.inferenceTimeout) + && Objects.equals(imageUrls, request.imageUrls); } @Override @@ -364,7 +419,8 @@ public int hashCode() { taskSettings, inputType, inferenceTimeout, - stream + stream, + imageUrls ); } @@ -381,6 +437,7 @@ public static class Builder { private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; private InferenceContext context; + private List imageUrl; private Builder() {} @@ -448,6 +505,11 @@ public Builder setContext(InferenceContext context) { return this; } + public Builder setImageUrl(List imageUrl) { + this.imageUrl = imageUrl; + return this; + } + public Request build() { return new Request( taskType, @@ -460,7 +522,8 @@ public Request build() { inputType, timeout, stream, - context + context, + imageUrl ); } } @@ -486,6 +549,8 @@ public String toString() { + this.getInferenceTimeout() + ", context=" + this.getContext() + + ", imageURL=" + + this.getImageUrls() + ")"; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java similarity index 51% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java index 37fca12f1697a..d57bdde89263f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java @@ -13,6 +13,7 @@ 
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -24,9 +25,11 @@ import java.util.Objects; /** - * Writes a text embedding result in the follow json format + * Writes a dense embedding result in the follow json format. The "text_embedding" part of the array name may change depending on the + * {@link TaskType} used to generate the embedding + *
  * {
- *     "text_embedding_bytes": [
+ *     "text_embedding_bits": [
  *         {
  *             "embedding": [
  *                 23
@@ -39,17 +42,36 @@
  *         }
  *     ]
  * }
+ * 
*/ // Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the // Embedding.merge method for bits. TODO: implement a proper merge method -public record TextEmbeddingBitResults(List embeddings) - implements - TextEmbeddingResults { +public final class DenseEmbeddingBitResults implements DenseEmbeddingResults { public static final String NAME = "text_embedding_service_bit_results"; - public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits"; + public static final String BITS_SUFFIX = "_bits"; + private final List embeddings; + private final String arrayName; - public TextEmbeddingBitResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new)); + public DenseEmbeddingBitResults(List embeddings) { + this(embeddings, TEXT_EMBEDDING); + } + + public DenseEmbeddingBitResults(List embeddings, String taskName) { + this.embeddings = embeddings; + this.arrayName = getArrayNameFromTaskName(taskName); + } + + public DenseEmbeddingBitResults(StreamInput in) throws IOException { + embeddings = in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new); + if (in.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + arrayName = in.readString(); + } else { + arrayName = getArrayNameFromTaskName(TEXT_EMBEDDING); + } + } + + public static String getArrayNameFromTaskName(String taskName) { + return taskName + BITS_SUFFIX; } @Override @@ -63,12 +85,15 @@ public int getFirstEmbeddingSize() { @Override public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddings.iterator()); + return ChunkedToXContentHelper.array(arrayName, embeddings.iterator()); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(embeddings); + if (out.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + out.writeString(arrayName); + } } @Override @@ -78,14 +103,12 @@ public String 
getWriteableName() { @Override public List transformToCoordinationFormat() { - return embeddings.stream() - .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false)) - .toList(); + return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(arrayName, embedding.toDoubleArray(), false)).toList(); } public Map asMap() { Map map = new LinkedHashMap<>(); - map.put(TEXT_EMBEDDING_BITS, embeddings); + map.put(arrayName, embeddings); return map; } @@ -94,12 +117,23 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingBitResults that = (TextEmbeddingBitResults) o; - return Objects.equals(embeddings, that.embeddings); + DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o; + return Objects.equals(embeddings, that.embeddings) && Objects.equals(arrayName, that.arrayName); } @Override public int hashCode() { - return Objects.hash(embeddings); + return Objects.hash(embeddings, arrayName); } + + @Override + public List embeddings() { + return embeddings; + } + + @Override + public String toString() { + return "DenseEmbeddingBitResults[" + "embeddings=" + embeddings + ", " + "arrayName=" + arrayName + ']'; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java similarity index 75% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java index 54f858cb20ae0..5b0e7c556bcce 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContent; @@ -31,7 +32,9 @@ import java.util.Objects; /** - * Writes a text embedding result in the follow json format + * Writes a dense embedding result in the follow json format. The "text_embedding" part of the array name may change depending on the + * {@link TaskType} used to generate the embedding + *
  * {
  *     "text_embedding_bytes": [
  *         {
@@ -46,13 +49,34 @@
  *         }
  *     ]
  * }
+ * 
*/ -public record TextEmbeddingByteResults(List embeddings) implements TextEmbeddingResults { +public final class DenseEmbeddingByteResults implements DenseEmbeddingResults { public static final String NAME = "text_embedding_service_byte_results"; - public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes"; + public static final String BYTES_SUFFIX = "_bytes"; + private final List embeddings; + private final String arrayName; - public TextEmbeddingByteResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new)); + public DenseEmbeddingByteResults(List embeddings) { + this(embeddings, TEXT_EMBEDDING); + } + + public DenseEmbeddingByteResults(List embeddings, String taskName) { + this.embeddings = embeddings; + this.arrayName = getArrayNameFromTaskName(taskName); + } + + public DenseEmbeddingByteResults(StreamInput in) throws IOException { + embeddings = in.readCollectionAsList(Embedding::new); + if (in.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + arrayName = in.readString(); + } else { + arrayName = getArrayNameFromTaskName(TEXT_EMBEDDING); + } + } + + public static String getArrayNameFromTaskName(String taskName) { + return taskName + BYTES_SUFFIX; } @Override @@ -65,12 +89,15 @@ public int getFirstEmbeddingSize() { @Override public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BYTES, embeddings.iterator()); + return ChunkedToXContentHelper.array(arrayName, embeddings.iterator()); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(embeddings); + if (out.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + out.writeString(arrayName); + } } @Override @@ -80,14 +107,12 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { - return embeddings.stream() - .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BYTES, 
embedding.toDoubleArray(), false)) - .toList(); + return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(arrayName, embedding.toDoubleArray(), false)).toList(); } public Map asMap() { Map map = new LinkedHashMap<>(); - map.put(TEXT_EMBEDDING_BYTES, embeddings); + map.put(arrayName, embeddings); return map; } @@ -96,13 +121,23 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingByteResults that = (TextEmbeddingByteResults) o; - return Objects.equals(embeddings, that.embeddings); + DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o; + return Objects.equals(embeddings, that.embeddings) && Objects.equals(arrayName, that.arrayName); } @Override public int hashCode() { - return Objects.hash(embeddings); + return Objects.hash(embeddings, arrayName); + } + + @Override + public List embeddings() { + return embeddings; + } + + @Override + public String toString() { + return "DenseEmbeddingByteResults[" + "embeddings=" + embeddings + ", " + "arrayName=" + arrayName + ']'; } // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java similarity index 75% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java index e68a5e4bd13b0..a67a6ff736fb4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java @@ -23,6 +23,7 @@ import 
org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import java.io.IOException; @@ -36,7 +37,9 @@ import java.util.stream.Collectors; /** - * Writes a text embedding result in the follow json format + * Writes a dense embedding result in the follow json format. The "text_embedding" array name may change depending on the + * {@link TaskType} used to generate the embedding + *
  * {
  *     "text_embedding": [
  *         {
@@ -51,31 +54,48 @@
  *         }
  *     ]
  * }
+ * 
*/ -public record TextEmbeddingFloatResults(List embeddings) implements TextEmbeddingResults { +public final class DenseEmbeddingFloatResults implements DenseEmbeddingResults { public static final String NAME = "text_embedding_service_results"; - public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); + private final List embeddings; + private final String arrayName; - public TextEmbeddingFloatResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(TextEmbeddingFloatResults.Embedding::new)); + public DenseEmbeddingFloatResults(List embeddings) { + this(embeddings, TEXT_EMBEDDING); + } + + public DenseEmbeddingFloatResults(List embeddings, String taskName) { + this.embeddings = embeddings; + this.arrayName = taskName; + } + + public DenseEmbeddingFloatResults(StreamInput in) throws IOException { + embeddings = in.readCollectionAsList(Embedding::new); + if (in.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + arrayName = in.readString(); + } else { + arrayName = TEXT_EMBEDDING; + } } @SuppressWarnings("deprecation") - TextEmbeddingFloatResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) { + DenseEmbeddingFloatResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) { this( legacyTextEmbeddingResults.embeddings() .stream() .map(embedding -> new Embedding(embedding.values())) - .collect(Collectors.toList()) + .collect(Collectors.toList()), + TEXT_EMBEDDING ); } - public static TextEmbeddingFloatResults of(List results) { + public static DenseEmbeddingFloatResults of(List results) { List embeddings = new ArrayList<>(results.size()); for (InferenceResults result : results) { if (result instanceof MlTextEmbeddingResults embeddingResult) { - embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingResult)); - } else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) { + embeddings.add(Embedding.of(embeddingResult)); + } else if (result 
instanceof ErrorInferenceResults errorResult) { if (errorResult.getException() instanceof ElasticsearchStatusException statusException) { throw statusException; } else { @@ -91,7 +111,7 @@ public static TextEmbeddingFloatResults of(List resu ); } } - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } @Override @@ -104,12 +124,15 @@ public int getFirstEmbeddingSize() { @Override public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.array(TEXT_EMBEDDING, embeddings.iterator()); + return ChunkedToXContentHelper.array(arrayName, embeddings.iterator()); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeCollection(embeddings); + if (out.getTransportVersion().supports(ML_MULTIMODAL_EMBEDDINGS)) { + out.writeString(arrayName); + } } @Override @@ -119,12 +142,12 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { - return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList(); + return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(arrayName, embedding.asDoubleArray(), false)).toList(); } public Map asMap() { Map map = new LinkedHashMap<>(); - map.put(TEXT_EMBEDDING, embeddings); + map.put(arrayName, embeddings); return map; } @@ -133,13 +156,23 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingFloatResults that = (TextEmbeddingFloatResults) o; - return Objects.equals(embeddings, that.embeddings); + DenseEmbeddingFloatResults that = (DenseEmbeddingFloatResults) o; + return Objects.equals(embeddings, that.embeddings) && Objects.equals(arrayName, that.arrayName); } @Override public int hashCode() { - return Objects.hash(embeddings); + return Objects.hash(embeddings, arrayName); + } + + @Override + public List 
embeddings() { + return embeddings; + } + + @Override + public String toString() { + return "DenseEmbeddingFloatResults[" + "embeddings=" + embeddings + ", " + "arrayName=" + arrayName + ']'; } // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java similarity index 61% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java index ea4e45ec67407..869e050718bea 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java @@ -7,7 +7,12 @@ package org.elasticsearch.xpack.core.inference.results; -public interface TextEmbeddingResults> extends EmbeddingResults { +import org.elasticsearch.TransportVersion; +import org.elasticsearch.inference.TaskType; + +public interface DenseEmbeddingResults> extends EmbeddingResults { + TransportVersion ML_MULTIMODAL_EMBEDDINGS = TransportVersion.fromName("ml_multimodal_embeddings"); + String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); /** * Returns the first text embedding entry in the result list's array size. 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java index 60bbeb624b532..1038bddf08baa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java @@ -44,7 +44,7 @@ * * Legacy text embedding results represents what was returned prior to the * {@link org.elasticsearch.TransportVersions#V_8_12_0} version. - * @deprecated use {@link TextEmbeddingFloatResults} instead + * @deprecated use {@link DenseEmbeddingFloatResults} instead */ @Deprecated public record LegacyTextEmbeddingResults(List embeddings) implements InferenceResults { @@ -114,8 +114,8 @@ public int hashCode() { return Objects.hash(embeddings); } - public TextEmbeddingFloatResults transformToTextEmbeddingResults() { - return new TextEmbeddingFloatResults(this); + public DenseEmbeddingFloatResults transformToTextEmbeddingResults() { + return new DenseEmbeddingFloatResults(this); } public record Embedding(float[] values) implements Writeable, ToXContentObject { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a1451efaa30ce..ff740979e9d8b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -51,7 +51,8 @@ protected InferenceAction.Request createTestInstance() { randomFrom(InputType.values()), TimeValue.timeValueMillis(randomLongBetween(1, 2048)), false, - new 
InferenceContext(randomAlphanumericOfLength(10)) + new InferenceContext(randomAlphanumericOfLength(10)), + null ); } @@ -499,7 +500,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); } case 1 -> new InferenceAction.Request( @@ -513,7 +515,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); case 2 -> { var changedInputs = new ArrayList(instance.getInput()); @@ -529,7 +532,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); } case 3 -> { @@ -551,7 +555,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); } case 4 -> { @@ -567,7 +572,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc nextInputType, instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); } case 5 -> new InferenceAction.Request( @@ -581,7 +587,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); case 6 -> { var newDuration = Duration.of( @@ -600,7 +607,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()), false, - instance.getContext() + instance.getContext(), + null ); } case 7 -> { @@ -616,7 +624,8 @@ protected 
InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getInputType(), instance.getInferenceTimeout(), instance.isStreaming(), - newContext + newContext, + null ); } default -> throw new UnsupportedOperationException(); @@ -709,7 +718,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInputType(), instance.getInferenceTimeout(), false, - InferenceContext.EMPTY_INSTANCE + InferenceContext.EMPTY_INSTANCE, + null ); } else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) && version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) { @@ -724,7 +734,8 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getInputType(), instance.getInferenceTimeout(), false, - instance.getContext() + instance.getContext(), + null ); } else { mutated = instance; @@ -842,7 +853,8 @@ public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEm InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, false, - new InferenceContext(randomAlphaOfLength(10)) + new InferenceContext(randomAlphaOfLength(10)), + null ); InferenceAction.Request deserializedInstance = copyWriteable( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java index 61b49075702a2..beaac57dace1d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java @@ -19,19 +19,19 @@ import static org.hamcrest.Matchers.is; -public class TextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase { - public static TextEmbeddingBitResults createRandomResults() { +public class 
TextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase { + public static DenseEmbeddingBitResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); } - return new TextEmbeddingBitResults(embeddingResults); + return new DenseEmbeddingBitResults(embeddingResults); } - private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { + private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); byte[] bytes = new byte[columns]; @@ -39,11 +39,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { bytes[i] = randomByte(); } - return new TextEmbeddingByteResults.Embedding(bytes); + return new DenseEmbeddingByteResults.Embedding(bytes); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); + var entity = new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingBitResults( + var entity = new DenseEmbeddingBitResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) ) ); @@ -85,10 +85,10 @@ public void 
testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingBitResults( + var results = new DenseEmbeddingBitResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).transformToCoordinationFormat(); @@ -96,18 +96,26 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false), - new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false) + new MlTextEmbeddingResults( + DenseEmbeddingBitResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + new double[] { 23F, 24F }, + false + ), + new MlTextEmbeddingResults( + DenseEmbeddingBitResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + new double[] { 25F, 26F }, + false + ) ) ) ); } public void testGetFirstEmbeddingSize() { - var firstEmbeddingSize = new TextEmbeddingBitResults( + var firstEmbeddingSize = new DenseEmbeddingBitResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).getFirstEmbeddingSize(); @@ -115,33 +123,33 @@ public void testGetFirstEmbeddingSize() { } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingBitResults::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingBitResults::new; } @Override - 
protected TextEmbeddingBitResults createTestInstance() { + protected DenseEmbeddingBitResults createTestInstance() { return createRandomResults(); } @Override - protected TextEmbeddingBitResults mutateInstance(TextEmbeddingBitResults instance) throws IOException { + protected DenseEmbeddingBitResults mutateInstance(DenseEmbeddingBitResults instance) throws IOException { // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list if (randomBoolean()) { // -1 to remove at least one item from the list int end = randomInt(instance.embeddings().size() - 1); - return new TextEmbeddingBitResults(instance.embeddings().subList(0, end)); + return new DenseEmbeddingBitResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); - return new TextEmbeddingBitResults(embeddings); + return new DenseEmbeddingBitResults(embeddings); } } public static Map buildExpectationByte(List> embeddings) { return Map.of( - TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, - embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() + DenseEmbeddingBitResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java index 60f45399cfb32..191b3683056f7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java @@ -20,19 +20,19 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { - public static TextEmbeddingByteResults createRandomResults() { +public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { + public static DenseEmbeddingByteResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); } - return new TextEmbeddingByteResults(embeddingResults); + return new DenseEmbeddingByteResults(embeddingResults); } - private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { + private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); byte[] bytes = new byte[columns]; @@ -40,11 +40,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { bytes[i] = randomByte(); } - return new TextEmbeddingByteResults.Embedding(bytes); + return new DenseEmbeddingByteResults.Embedding(bytes); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); + var entity = new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -60,10 +60,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new 
TextEmbeddingByteResults( + var entity = new DenseEmbeddingByteResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) ) ); @@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingByteResults( + var results = new DenseEmbeddingByteResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).transformToCoordinationFormat(); @@ -97,18 +97,26 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false), - new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false) + new MlTextEmbeddingResults( + DenseEmbeddingByteResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + new double[] { 23F, 24F }, + false + ), + new MlTextEmbeddingResults( + DenseEmbeddingByteResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + new double[] { 25F, 26F }, + false + ) ) ) ); } public void testGetFirstEmbeddingSize() { - var firstEmbeddingSize = new TextEmbeddingByteResults( + var firstEmbeddingSize = new DenseEmbeddingByteResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 
23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).getFirstEmbeddingSize(); @@ -116,43 +124,43 @@ public void testGetFirstEmbeddingSize() { } public void testEmbeddingMerge() { - TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 }); - TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 }); - TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 }); - TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 }))); + DenseEmbeddingByteResults.Embedding embedding1 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 }); + DenseEmbeddingByteResults.Embedding embedding2 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 }); + DenseEmbeddingByteResults.Embedding embedding3 = new DenseEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 }); + DenseEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 }))); mergedEmbedding = mergedEmbedding.merge(embedding3); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 }))); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 }))); } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingByteResults::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingByteResults::new; } @Override - protected TextEmbeddingByteResults createTestInstance() { + protected DenseEmbeddingByteResults createTestInstance() { return createRandomResults(); } @Override - protected TextEmbeddingByteResults 
mutateInstance(TextEmbeddingByteResults instance) throws IOException { + protected DenseEmbeddingByteResults mutateInstance(DenseEmbeddingByteResults instance) throws IOException { // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list if (randomBoolean()) { // -1 to remove at least one item from the list int end = randomInt(instance.embeddings().size() - 1); - return new TextEmbeddingByteResults(instance.embeddings().subList(0, end)); + return new DenseEmbeddingByteResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); - return new TextEmbeddingByteResults(embeddings); + return new DenseEmbeddingByteResults(embeddings); } } public static Map buildExpectationByte(List> embeddings) { return Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() + DenseEmbeddingByteResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java index 8cdd98bcdebc6..4252c65a430de 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java @@ -20,30 +20,30 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -public class TextEmbeddingFloatResultsTests extends 
AbstractWireSerializingTestCase { - public static TextEmbeddingFloatResults createRandomResults() { +public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase { + public static DenseEmbeddingFloatResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); } - return new TextEmbeddingFloatResults(embeddingResults); + return new DenseEmbeddingFloatResults(embeddingResults); } - private static TextEmbeddingFloatResults.Embedding createRandomEmbedding() { + private static DenseEmbeddingFloatResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); float[] floats = new float[columns]; for (int i = 0; i < columns; i++) { floats[i] = randomFloat(); } - return new TextEmbeddingFloatResults.Embedding(floats); + return new DenseEmbeddingFloatResults.Embedding(floats); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F }))); + var entity = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingFloatResults( + var entity = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.2F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F }), + new DenseEmbeddingFloatResults.Embedding(new 
float[] { 0.2F }) ) ); @@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingFloatResults( + var results = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) ) ).transformToCoordinationFormat(); @@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false), - new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false) + new MlTextEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false), + new MlTextEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false) ) ) ); } public void testGetFirstEmbeddingSize() { - var firstEmbeddingSize = new TextEmbeddingFloatResults( + var firstEmbeddingSize = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) ) ).getFirstEmbeddingSize(); @@ -116,51 +116,54 @@ public void testGetFirstEmbeddingSize() { } public void testEmbeddingMerge() { - TextEmbeddingFloatResults.Embedding embedding1 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }); - TextEmbeddingFloatResults.Embedding embedding2 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f }); - 
TextEmbeddingFloatResults.Embedding embedding3 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f }); - TextEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f }))); + DenseEmbeddingFloatResults.Embedding embedding1 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }); + DenseEmbeddingFloatResults.Embedding embedding2 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f }); + DenseEmbeddingFloatResults.Embedding embedding3 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f }); + DenseEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f }))); mergedEmbedding = mergedEmbedding.merge(embedding3); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f }))); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f }))); } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingFloatResults::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingFloatResults::new; } @Override - protected TextEmbeddingFloatResults createTestInstance() { + protected DenseEmbeddingFloatResults createTestInstance() { return createRandomResults(); } @Override - protected TextEmbeddingFloatResults mutateInstance(TextEmbeddingFloatResults instance) throws IOException { + protected DenseEmbeddingFloatResults mutateInstance(DenseEmbeddingFloatResults instance) throws IOException { // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list if (randomBoolean()) { // -1 to remove at least one item from 
the list int end = randomInt(instance.embeddings().size() - 1); - return new TextEmbeddingFloatResults(instance.embeddings().subList(0, end)); + return new DenseEmbeddingFloatResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } } public static Map buildExpectationFloat(List embeddings) { - return Map.of(TextEmbeddingFloatResults.TEXT_EMBEDDING, embeddings.stream().map(TextEmbeddingFloatResults.Embedding::new).toList()); + return Map.of( + DenseEmbeddingFloatResults.TEXT_EMBEDDING, + embeddings.stream().map(DenseEmbeddingFloatResults.Embedding::new).toList() + ); } public static Map buildExpectationByte(List embeddings) { return Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList() + DenseEmbeddingByteResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING), + embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList() ); } public static Map buildExpectationBinary(List embeddings) { - return Map.of("text_embedding_bits", embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList()); + return Map.of("text_embedding_bits", embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java index a2b0a32e77b05..315e2af38a76b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java @@ -12,14 +12,14 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.esql.inference.InferenceOperator; /** * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting - * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. + * {@link DenseEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. */ class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; @@ -39,7 +39,7 @@ public void close() { * Adds an inference response to the output builder. * *

- * If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown. + * If the response is null or not of type {@link DenseEmbeddingResults} an {@link IllegalStateException} is thrown. * Else, the embedding vector is added to the output block as a multi-value position. *

* @@ -55,7 +55,7 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) { return; } - TextEmbeddingResults embeddingResults = inferenceResults(inferenceResponse); + DenseEmbeddingResults embeddingResults = inferenceResults(inferenceResponse); var embeddings = embeddingResults.embeddings(); if (embeddings.isEmpty()) { @@ -82,17 +82,17 @@ public Page buildOutput() { return inputPage.appendBlock(outputBlock); } - private TextEmbeddingResults inferenceResults(InferenceAction.Response inferenceResponse) { - return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class); + private DenseEmbeddingResults inferenceResults(InferenceAction.Response inferenceResponse) { + return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, DenseEmbeddingResults.class); } /** * Extracts the embedding as a float array from the embedding result. */ - private static float[] getEmbeddingAsFloatArray(TextEmbeddingResults embedding) { + private static float[] getEmbeddingAsFloatArray(DenseEmbeddingResults embedding) { return switch (embedding.embeddings().get(0)) { - case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values(); - case TextEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values()); + case DenseEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values(); + case DenseEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values()); default -> throw new IllegalArgumentException( "Unsupported embedding type: " + embedding.embeddings().get(0).getClass().getName() diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java index ea77c6bed3c38..0b006574190de 100644 --- 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java @@ -15,8 +15,8 @@ import org.elasticsearch.compute.test.RandomBlock; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import java.util.List; @@ -206,19 +206,19 @@ private byte[] randomByteEmbedding(int dimension) { } private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) { - var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding); - var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult)); + var embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding); + var textEmbeddingResults = new DenseEmbeddingFloatResults(List.of(embeddingResult)); return new InferenceAction.Response(textEmbeddingResults); } private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) { - var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding); - var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult)); + var embeddingResult = new DenseEmbeddingByteResults.Embedding(embedding); + var textEmbeddingResults = new DenseEmbeddingByteResults(List.of(embeddingResult)); return new InferenceAction.Response(textEmbeddingResults); } private static InferenceAction.Response createEmptyFloatEmbeddingResponse() { - var textEmbeddingResults = new TextEmbeddingFloatResults(List.of()); + var 
textEmbeddingResults = new DenseEmbeddingFloatResults(List.of()); return new InferenceAction.Response(textEmbeddingResults); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java index 6ff9a90b70b16..06441be9e7148 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; import org.hamcrest.Matcher; import org.junit.Before; @@ -23,7 +23,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase { +public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase { private static final String SIMPLE_INFERENCE_ID = "test_text_embedding"; private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension @@ -89,15 +89,15 @@ private void assertTextEmbeddingResults(Page inputPage, Page resultPage) { } @Override - protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) { + protected DenseEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) { // For text embedding, we expect one input text per request String inputText = request.getInput().get(0); // Generate a deterministic mock 
embedding based on the input text float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION); - var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding); - return new TextEmbeddingFloatResults(List.of(embeddingResult)); + var embeddingResult = new DenseEmbeddingFloatResults.Embedding(mockEmbedding); + return new DenseEmbeddingFloatResults(List.of(embeddingResult)); } @Override diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 44c9d0463cd05..0234630e70fb2 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -120,15 +120,15 @@ public void start(Model model, TimeValue timeout, ActionListener listen public void close() throws IOException {} protected List chunkInputs(ChunkInferenceInput input) { - ChunkingSettings chunkingSettings = input.chunkingSettings(); - String inputText = input.input(); + ChunkingSettings chunkingSettings = input.getChunkingSettings(); + String inputText = input.getInput(); if (chunkingSettings == null) { return List.of(new ChunkedInput(inputText, 0, inputText.length())); } List chunkedInputs = new ArrayList<>(); if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.NONE) { - var offsets = NoopChunker.INSTANCE.chunk(input.input(), chunkingSettings); + var offsets = NoopChunker.INSTANCE.chunk(input.getInput(), chunkingSettings); List ret = new ArrayList<>(); for (var offset : offsets) { ret.add(new ChunkedInput(inputText.substring(offset.start(), offset.end()), offset.start(), offset.end())); diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java index 728c39b634bd0..c4f9349982859 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java @@ -103,7 +103,8 @@ public void infer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener listener + ActionListener listener, + List imageUrls ) { switch (model.getConfigurations().getTaskType()) { case COMPLETION -> listener.onResponse(makeChatCompletionResults(input)); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 051b6dbf3e8fa..fc3d2a0914305 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -36,7 +36,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -117,7 +117,8 @@ public void 
infer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener listener + ActionListener listener, + List imageUrls ) { switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { @@ -167,22 +168,22 @@ public void chunkedInfer( } } - private TextEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) { - List embeddings = new ArrayList<>(); + private DenseEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) { + List embeddings = new ArrayList<>(); for (String inputString : input) { List floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType()); - embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings)); + embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings)); } - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) { var results = new ArrayList(); for (ChunkInferenceInput input : inputs) { List chunkedInput = chunkInputs(input); - List chunks = chunkedInput.stream() + List chunks = chunkedInput.stream() .map( - c -> new TextEmbeddingFloatResults.Chunk( + c -> new DenseEmbeddingFloatResults.Chunk( makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) ) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 962fc9e1ee818..f1f62e58e64c5 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -120,7 +120,8 @@ public void infer( Map taskSettingsMap, InputType inputType, TimeValue timeout, - ActionListener listener + ActionListener listener, + List imageUrls ) { TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 86dcb56fa369d..7ca970107a3a1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -111,7 +111,8 @@ public void infer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener listener + ActionListener listener, + List imageUrls ) { switch (model.getConfigurations().getTaskType()) { case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 28a191a1bbfac..5592667a13ae1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -35,9 +35,9 @@ 
import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.DequeUtils; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; import java.util.ArrayList; @@ -125,7 +125,8 @@ public void infer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener listener + ActionListener listener, + List imageUrls ) { switch (model.getConfigurations().getTaskType()) { case COMPLETION -> listener.onResponse(makeChatCompletionResults(input)); @@ -189,16 +190,16 @@ public void cancel() {} }); } - private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { - var embeddings = new ArrayList(); + private DenseEmbeddingFloatResults makeTextEmbeddingResults(List input) { + var embeddings = new ArrayList(); for (int i = 0; i < input.size(); i++) { var values = new float[5]; for (int j = 0; j < 5; j++) { values[j] = random.nextFloat(); } - embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(values)); } - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } private InferenceServiceResults.Result completionChunk(String delta) { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 92eea9599ec5d..dbe89d7048a06 100644 --- 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -652,7 +652,7 @@ private static Model buildModelWithUnknownField(String inferenceEntityId) { private static ServiceSettings createServiceSettings(TaskType taskType) { return switch (taskType) { - case TEXT_EMBEDDING -> new TestModel.TestServiceSettings( + case TEXT_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING -> new TestModel.TestServiceSettings( "model", randomIntBetween(2, 100), randomFrom(SimilarityMeasure.values()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index e7008c2292def..9e32effd506d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -21,13 +21,13 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings; import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings; @@ -657,10 +657,14 @@ private static void addInferenceResultsNamedWriteables(List inputs = requests.stream() - .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) + .map(r -> new ChunkInferenceTextInput(r.input, r.chunkingSettings)) .collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index da071442d6c1b..7b36c66d7a441 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.inference.ChunkInferenceImageInput; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -44,15 +45,27 @@ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) { + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input, boolean isImage) { public String chunkText() { return input.substring(chunk.start(), chunk.end()); } } public record 
BatchRequest(List requests) { - public Supplier> inputs() { - return () -> requests.stream().map(Request::chunkText).collect(Collectors.toList()); + public Supplier> textInputs() { + return () -> requests.stream() + .filter(request -> request.isImage() == false) + .map(Request::chunkText) + .collect(Collectors.toList()); + } + + /** + * Since images are not chunked, no String copying takes place when calling {@link Request#chunkText()}, so the list can + * be returned directly without using a {@link Supplier} + * @return a list of Strings representing image URLs + */ + public List imageUrlInputs() { + return requests.stream().filter(Request::isImage).map(Request::chunkText).collect(Collectors.toList()); } } @@ -98,7 +111,7 @@ public EmbeddingRequestChunker( } Map chunkers = inputs.stream() - .map(ChunkInferenceInput::chunkingSettings) + .map(ChunkInferenceInput::getChunkingSettings) .filter(Objects::nonNull) .map(ChunkingSettings::getChunkingStrategy) .distinct() @@ -107,12 +120,21 @@ public EmbeddingRequestChunker( List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { - ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); + ChunkInferenceInput chunkInferenceInput = inputs.get(inputIndex); + ChunkingSettings chunkingSettings = chunkInferenceInput.getChunkingSettings(); if (chunkingSettings == null) { chunkingSettings = defaultChunkingSettings; } - Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); - String inputString = inputs.get(inputIndex).input(); + Chunker chunker; + boolean isImage = chunkInferenceInput instanceof ChunkInferenceImageInput; + if (isImage) { + // Do not chunk image URLs + chunker = NoopChunker.INSTANCE; + chunkingSettings = NoneChunkingSettings.INSTANCE; + } else { + chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); + } + String inputString = chunkInferenceInput.getInput(); List 
chunks = chunker.chunk(inputString, chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); @@ -130,7 +152,7 @@ public EmbeddingRequestChunker( } else { resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end()); } - allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString)); + allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString, isImage)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index f9fd3a2011ee0..ff13f65e92fe2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -18,24 +18,26 @@ public class EmbeddingsInput extends InferenceInputs { private final Supplier> inputListSupplier; private final InputType inputType; + private final List imageUrls; private final AtomicBoolean supplierInvoked = new AtomicBoolean(); public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(() -> input, inputType, false); + this(() -> input, inputType, false, null); } - public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { - this(() -> input, inputType, stream); + public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream, List imageUrls) { + this(() -> input, inputType, stream, imageUrls); } - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { - this(inputSupplier, inputType, false); + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, List imageUrls) { + this(inputSupplier, inputType, 
false, imageUrls); } - private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream) { + private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream, List imageUrls) { super(stream); this.inputListSupplier = Objects.requireNonNull(inputSupplier); this.inputType = inputType; + this.imageUrls = imageUrls; } /** @@ -55,6 +57,10 @@ public InputType getInputType() { return this.inputType; } + public List getImageUrls() { + return imageUrls; + } + @Override public boolean isSingleInput() { // We can't measure the size of the input list without executing diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java index a96ebc0048f70..014a5693e436e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -57,7 +57,7 @@ public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult 
response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); } @@ -81,9 +81,9 @@ public record EmbeddingFloatResult(List embeddingResu }, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY); } - public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { - return new TextEmbeddingFloatResults( - embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() + public DenseEmbeddingFloatResults toTextEmbeddingFloatResults() { + return new DenseEmbeddingFloatResults( + embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index b7c01ce817b32..62c6831c6bf78 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -432,7 +432,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { private void validateServiceSettings(MinimalServiceSettings settings, MinimalServiceSettings resolved) { switch (settings.taskType()) { - case SPARSE_EMBEDDING, TEXT_EMBEDDING -> { + case SPARSE_EMBEDDING, TEXT_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING -> { } default -> throw new IllegalArgumentException( "Wrong [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index f0b25bd427b69..99db55c38baee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -71,11 +71,16 @@ public void infer( Map taskSettings, InputType inputType, @Nullable TimeValue timeout, - ActionListener listener + ActionListener listener, + @Nullable List imageUrls ) { timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService); init(); - var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); + // TODO: combine images and text inputs into one list? + if (input == null) { + input = List.of(); + } + var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream, imageUrls); doInfer(model, inferenceInput, taskSettings, timeout, listener); } @@ -87,7 +92,8 @@ private static InferenceInputs createInput( @Nullable String query, @Nullable Boolean returnDocuments, @Nullable Integer topN, - boolean stream + boolean stream, + @Nullable List imageUrls ) { return switch (model.getTaskType()) { case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); @@ -99,13 +105,13 @@ private static InferenceInputs createInput( } yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); } - case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { + case TEXT_EMBEDDING, SPARSE_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING -> { ValidationException validationException = new ValidationException(); service.validateInputType(inputType, model, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new EmbeddingsInput(input, inputType, stream); + yield new EmbeddingsInput(input, inputType, stream, imageUrls); } default -> throw 
new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index f474850b9f190..bd0cee4e2a3fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -359,7 +359,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java index 4e73f03e2898b..0a314202c922c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java @@ -9,7 +9,7 @@ import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentParser; -import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -70,20 +70,20 @@ public class AlibabaCloudSearchEmbeddingsResponseEntity extends AlibabaCloudSear * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { return fromResponse(request, response, parser -> { positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = XContentParserUtils.parseList( + List embeddingList = XContentParserUtils.parseList( parser, AlibabaCloudSearchEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); }); } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -95,7 +95,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array parser.skipChildren(); - return TextEmbeddingFloatResults.Embedding.of(embeddingValues); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 11204018a5523..c63ae66b4dfb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -167,7 +167,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java index 831bf9938c211..61ec5f0c39790 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -16,7 +16,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import 
org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.response.XContentUtils; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest; @@ -48,7 +48,7 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) { throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); } - public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) { + public static DenseEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) { var charset = StandardCharsets.UTF_8; var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer())); @@ -63,13 +63,13 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons var embeddingList = parseEmbeddings(jsonParser, provider); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } catch (IOException e) { throw new ElasticsearchException(e); } } - private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) + private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) throws IOException { switch (provider) { case AMAZONTITAN -> { @@ -82,7 +82,7 @@ private static List parseEmbeddings(XConten } } - private static List parseTitanEmbeddings(XContentParser parser) throws IOException { + private static List parseTitanEmbeddings(XContentParser parser) throws IOException { /* Titan response: { @@ -92,11 +92,11 @@ private static List parseTitanEmbeddings(XC */ positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - var embeddingValues = 
TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + var embeddingValues = DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); return List.of(embeddingValues); } - private static List parseCohereEmbeddings(XContentParser parser) throws IOException { + private static List parseCohereEmbeddings(XContentParser parser) throws IOException { /* Cohere response: { @@ -111,7 +111,7 @@ private static List parseCohereEmbeddings(X */ positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( parser, AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem ); @@ -119,9 +119,9 @@ private static List parseCohereEmbeddings(X return embeddingList; } - private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException { List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 7578aa702ad7c..8967889440631 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -153,7 +153,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(new 
EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 077e5361dd46f..f946e29a93acb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -304,7 +304,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 2561f198075e2..5f9b998af23f4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -72,7 +72,13 @@ public class CohereService extends SenderService implements RerankingInferenceSe public static final String NAME = "cohere"; private static final String SERVICE_NAME = "Cohere"; - private static final EnumSet 
supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK); + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.IMAGE_EMBEDDING, + TaskType.MULTIMODAL_EMBEDDING, + TaskType.COMPLETION, + TaskType.RERANK + ); public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, @@ -176,13 +182,14 @@ private static CohereModel createModel( ConfigurationParseContext context ) { return switch (taskType) { - case TEXT_EMBEDDING -> new CohereEmbeddingsModel( + case TEXT_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING -> new CohereEmbeddingsModel( inferenceEntityId, serviceSettings, taskSettings, chunkingSettings, secretSettings, - context + context, + taskType ); case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context); @@ -308,7 +315,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 121f0e1e80a96..c7a64e323107d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -8,6 +8,7 @@ package 
org.elasticsearch.xpack.inference.services.cohere.action; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; @@ -44,9 +45,21 @@ */ public class CohereActionCreator implements CohereActionVisitor { - private static final ResponseHandler EMBEDDINGS_HANDLER = new CohereResponseHandler( + private static final ResponseHandler TEXT_EMBEDDINGS_HANDLER = new CohereResponseHandler( "cohere text embedding", - CohereEmbeddingsResponseEntity::fromResponse, + (request, response) -> CohereEmbeddingsResponseEntity.fromResponse(request, response, TaskType.TEXT_EMBEDDING), + false + ); + + private static final ResponseHandler IMAGE_EMBEDDINGS_HANDLER = new CohereResponseHandler( + "cohere image embedding", + (request, response) -> CohereEmbeddingsResponseEntity.fromResponse(request, response, TaskType.IMAGE_EMBEDDING), + false + ); + + private static final ResponseHandler MULTIMODAL_EMBEDDINGS_HANDLER = new CohereResponseHandler( + "cohere multimodal embedding", + (request, response) -> CohereEmbeddingsResponseEntity.fromResponse(request, response, TaskType.MULTIMODAL_EMBEDDING), false ); @@ -82,7 +95,12 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); - case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest( + inferenceInputs.getInputs(), + requestInputType, + overriddenModel, + inferenceInputs.getImageUrls() + ); }; }; @@ -90,13 +108,22 @@ public ExecutableAction create(CohereEmbeddingsModel 
model, Map var requestManager = new GenericRequestManager<>( serviceComponents.threadPool(), model, - EMBEDDINGS_HANDLER, + getResponseHandlerForEmbeddingTaskType(model.getTaskType()), requestCreator, EmbeddingsInput.class ); return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); } + private static ResponseHandler getResponseHandlerForEmbeddingTaskType(TaskType taskType) { + return switch (taskType) { + case TEXT_EMBEDDING -> TEXT_EMBEDDINGS_HANDLER; + case IMAGE_EMBEDDING -> IMAGE_EMBEDDINGS_HANDLER; + case MULTIMODAL_EMBEDDING -> MULTIMODAL_EMBEDDINGS_HANDLER; + default -> throw new IllegalArgumentException("Invalid TaskType for embeddings action"); + }; + } + @Override public ExecutableAction create(CohereRerankModel model, Map taskSettings) { var overriddenModel = CohereRerankModel.of(model, taskSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 525674cc9b2ef..9a3ba04b36fa4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -34,14 +34,16 @@ public CohereEmbeddingsModel( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secrets, - ConfigurationParseContext context + ConfigurationParseContext context, + TaskType taskType ) { this( inferenceId, CohereEmbeddingsServiceSettings.fromMap(serviceSettings, context), CohereEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, - DefaultSecretSettings.fromMap(secrets) + DefaultSecretSettings.fromMap(secrets), + taskType ); } @@ -51,10 +53,11 @@ public CohereEmbeddingsModel( CohereEmbeddingsServiceSettings 
serviceSettings, CohereEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, - @Nullable DefaultSecretSettings secretSettings + @Nullable DefaultSecretSettings secretSettings, + TaskType taskType ) { super( - new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, CohereService.NAME, serviceSettings, taskSettings, chunkingSettings), + new ModelConfigurations(modelId, taskType, CohereService.NAME, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings.getCommonSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java index 2d52a8a9dadbb..28ed71dfe2742 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java @@ -38,6 +38,12 @@ public class CohereUtils { public static final String STREAM_FIELD = "stream"; public static final String TEXTS_FIELD = "texts"; public static final String USER_FIELD = "user"; + public static final String INPUTS_FIELD = "inputs"; + public static final String CONTENT_FIELD = "content"; + public static final String CONTENT_TYPE_FIELD = "type"; + public static final String IMAGE_URL_FIELD = "image_url"; + public static final String TEXT_FIELD = "text"; + public static final String URL_FIELD = "url"; public static Header createRequestSourceHeader() { return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java 
index 6fb8eb5bec7b8..6c2d6dd7c97a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -22,14 +22,27 @@ import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.CONTENT_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.CONTENT_TYPE_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.IMAGE_URL_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.INPUTS_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.TEXT_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.URL_FIELD; + public class CohereV2EmbeddingsRequest extends CohereRequest { private final List input; private final InputType inputType; private final CohereEmbeddingsTaskSettings taskSettings; private final CohereEmbeddingType embeddingType; + private final List imageUrls; - public CohereV2EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { + public CohereV2EmbeddingsRequest( + List input, + InputType inputType, + CohereEmbeddingsModel embeddingsModel, + List imageUrls + ) { super( CohereAccount.of(embeddingsModel), embeddingsModel.getInferenceEntityId(), @@ -41,6 +54,7 @@ public CohereV2EmbeddingsRequest(List input, InputType inputType, Cohere this.inputType = Optional.ofNullable(inputType).orElse(InputType.SEARCH); // inputType is required in v2 taskSettings = embeddingsModel.getTaskSettings(); embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); + this.imageUrls = imageUrls == null ? 
List.of() : imageUrls; } @Override @@ -51,7 +65,15 @@ protected List pathSegments() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CohereUtils.TEXTS_FIELD, input); + builder.startArray(INPUTS_FIELD); + for (String anInput : input) { + addInput(builder, anInput, true); + } + for (String url : imageUrls) { + addInput(builder, url, false); + } + builder.endArray(); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); // prefer the root level inputType over task settings input type if (InputType.isSpecified(inputType)) { @@ -66,4 +88,32 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + private static void addInput(XContentBuilder builder, String anInput, boolean isText) throws IOException { + builder.startObject(); + builder.startArray(CONTENT_FIELD); + if (isText) { + addText(builder, anInput); + } else { + addImageUrl(builder, anInput); + } + builder.endArray(); + builder.endObject(); + } + + private static void addText(XContentBuilder builder, String anInput) throws IOException { + builder.startObject(); + builder.field(CONTENT_TYPE_FIELD, TEXT_FIELD); + builder.field(TEXT_FIELD, anInput); + builder.endObject(); + } + + private static void addImageUrl(XContentBuilder builder, String url) throws IOException { + builder.startObject(); + builder.field(CONTENT_TYPE_FIELD, IMAGE_URL_FIELD); + builder.startObject(IMAGE_URL_FIELD); + builder.field(URL_FIELD, url); + builder.endObject(); + builder.endObject(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java index b4a2e142b3792..0660a9eecdcdf 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java @@ -7,17 +7,18 @@ package org.elasticsearch.xpack.inference.services.cohere.response; +import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -38,14 +39,15 @@ public class CohereEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response"; - private static final Map> EMBEDDING_PARSERS = Map.of( - toLowerCase(CohereEmbeddingType.FLOAT), - CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray, - toLowerCase(CohereEmbeddingType.INT8), - 
CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray, - toLowerCase(CohereEmbeddingType.BINARY), - CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray - ); + private static final Map> EMBEDDING_PARSERS = + Map.of( + toLowerCase(CohereEmbeddingType.FLOAT), + CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray, + toLowerCase(CohereEmbeddingType.INT8), + CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray, + toLowerCase(CohereEmbeddingType.BINARY), + CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray + ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); private static String supportedEmbeddingTypes() { @@ -136,7 +138,7 @@ private static String supportedEmbeddingTypes() { * * */ - public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + public static InferenceServiceResults fromResponse(Request request, HttpResult response, TaskType taskType) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -149,10 +151,10 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r token = jsonParser.currentToken(); if (token == XContentParser.Token.START_OBJECT) { - return parseEmbeddingsObject(jsonParser); + return parseEmbeddingsObject(jsonParser, taskType); } else if (token == XContentParser.Token.START_ARRAY) { // if the request did not specify the embedding types then it will default to floats - return parseFloatEmbeddingsArray(jsonParser); + return parseFloatEmbeddingsArray(jsonParser, taskType); } else { throwUnknownToken(token, jsonParser); } @@ -163,7 +165,7 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r } } - private static InferenceServiceResults parseEmbeddingsObject(XContentParser parser) 
throws IOException { + private static InferenceServiceResults parseEmbeddingsObject(XContentParser parser, TaskType taskType) throws IOException { XContentParser.Token token = parser.nextToken(); while (token != null && token != XContentParser.Token.END_OBJECT) { @@ -171,7 +173,7 @@ private static InferenceServiceResults parseEmbeddingsObject(XContentParser pars var embeddingValueParser = EMBEDDING_PARSERS.get(parser.currentName()); if (embeddingValueParser != null) { parser.nextToken(); - return embeddingValueParser.apply(parser); + return embeddingValueParser.apply(parser, taskType); } } token = parser.nextToken(); @@ -185,24 +187,24 @@ private static InferenceServiceResults parseEmbeddingsObject(XContentParser pars ); } - private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException { + private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser, TaskType taskType) throws IOException { // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); - return new TextEmbeddingBitResults(embeddingList); + return new DenseEmbeddingBitResults(embeddingList); } - private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { + private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser, TaskType taskType) throws IOException { var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); - return new TextEmbeddingByteResults(embeddingList); + return new DenseEmbeddingByteResults(embeddingList); } - private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException { + private static DenseEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException { 
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry); - return TextEmbeddingByteResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingByteResults.Embedding.of(embeddingValuesList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { @@ -220,16 +222,16 @@ private static void checkByteBounds(short value) { } } - private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException { + private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser, TaskType taskType) throws IOException { var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList, taskType.toString()); } - private static TextEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private CohereEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index fd29b02012185..2e7257e2fa333 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -317,7 +317,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java index f665c0be81511..e1146c9782a48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java @@ -14,9 +14,9 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.common.MapPathExtractor; import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; @@ -174,7 +174,7 @@ private 
interface EmbeddingConverter { private static class FloatEmbeddings implements EmbeddingConverter { - private final List embeddings; + private final List embeddings; FloatEmbeddings() { this.embeddings = new ArrayList<>(); @@ -182,17 +182,17 @@ private static class FloatEmbeddings implements EmbeddingConverter { public void toEmbedding(Object entry, String fieldName) { var embeddingsAsListFloats = convertToListOfFloats(entry, fieldName); - embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); } - public TextEmbeddingFloatResults getResults() { - return new TextEmbeddingFloatResults(embeddings); + public DenseEmbeddingFloatResults getResults() { + return new DenseEmbeddingFloatResults(embeddings); } } private static class ByteEmbeddings implements EmbeddingConverter { - private final List embeddings; + private final List embeddings; ByteEmbeddings() { this.embeddings = new ArrayList<>(); @@ -200,17 +200,17 @@ private static class ByteEmbeddings implements EmbeddingConverter { public void toEmbedding(Object entry, String fieldName) { var convertedEmbeddings = convertToListOfBytes(entry, fieldName); - this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings)); } - public TextEmbeddingByteResults getResults() { - return new TextEmbeddingByteResults(embeddings); + public DenseEmbeddingByteResults getResults() { + return new DenseEmbeddingByteResults(embeddings); } } private static class BitEmbeddings implements EmbeddingConverter { - private final List embeddings; + private final List embeddings; BitEmbeddings() { this.embeddings = new ArrayList<>(); @@ -218,11 +218,11 @@ private static class BitEmbeddings implements EmbeddingConverter { public void toEmbedding(Object entry, String fieldName) { var convertedEmbeddings = convertToListOfBits(entry, fieldName); - 
this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings)); } - public TextEmbeddingBitResults getResults() { - return new TextEmbeddingBitResults(embeddings); + public DenseEmbeddingBitResults getResults() { + return new DenseEmbeddingBitResults(embeddings); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index cc871da8eb860..2026496189b66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -388,7 +388,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } return; @@ -405,7 +409,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } return; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 89258d5716e8e..8092cf02fe868 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -38,9 +38,9 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -612,7 +612,8 @@ public void infer( Map taskSettings, InputType inputType, @Nullable TimeValue timeout, - ActionListener listener + ActionListener listener, + List imageUrls ) { timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, getClusterService()); if (model instanceof ElasticsearchInternalModel esModel) { @@ -647,7 +648,7 @@ public void inferTextEmbedding( ); ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(TextEmbeddingFloatResults.of(inferenceResult.getInferenceResults())) + (l, inferenceResult) -> l.onResponse(DenseEmbeddingFloatResults.of(inferenceResult.getInferenceResults())) ); var maybeDeployListener = mlResultsListener.delegateResponse( @@ -757,11 +758,11 @@ private static void translateToChunkedResult( ActionListener chunkPartListener ) { if 
(taskType == TaskType.TEXT_EMBEDDING) { - var translated = new ArrayList(); + var translated = new ArrayList(); for (var inferenceResult : inferenceResults) { if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult) { - translated.add(new TextEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat())); + translated.add(new DenseEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat())); } else if (inferenceResult instanceof ErrorInferenceResults error) { chunkPartListener.onFailure(error.getException()); return; @@ -772,7 +773,7 @@ private static void translateToChunkedResult( return; } } - chunkPartListener.onResponse(new TextEmbeddingFloatResults(translated)); + chunkPartListener.onResponse(new DenseEmbeddingFloatResults(translated)); } else { // sparse var translated = new ArrayList(); @@ -1125,7 +1126,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - batch.batch().inputs().get(), + batch.batch().textInputs().get(), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 97bd2502d25b6..8ed91b68cf201 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -366,7 +366,13 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); + doInfer( + model, + new 
EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + taskSettings, + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java index 499fe9ae0c6c7..67527698bf02c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -70,7 +70,7 @@ public class GoogleAiStudioEmbeddingsResponseEntity { * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -81,16 +81,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, 
HttpResult positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, GoogleAiStudioEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "values", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private GoogleAiStudioEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 41678689e8b9d..048601df4428b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -293,7 +293,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new 
EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java index b4038e42c62cb..94272815e8db2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -64,7 +64,7 @@ public class GoogleVertexAiEmbeddingsResponseEntity { * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -75,16 +75,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult positionParserAtTokenAfterField(jsonParser, "predictions", 
FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, GoogleVertexAiEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent consumeUntilObjectEnd(parser); consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValueList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValueList); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d0a98d8252923..e22bd358a4910 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -174,7 +174,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, 
+ request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 775a4e90ae034..d2123b5aa499d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -28,9 +28,9 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -115,20 +115,14 @@ protected void doChunkedInfer( ); // TODO chunking sparse embeddings not implemented - doInfer( - model, - new EmbeddingsInput(inputs.stream().map(ChunkInferenceInput::input).toList(), inputType), - taskSettings, - timeout, - inferListener - ); + doInfer(model, new EmbeddingsInput(ChunkInferenceInput.inputs(inputs), inputType), taskSettings, timeout, inferListener); } private static List translateToChunkedResults( List inputs, InferenceServiceResults inferenceResults ) { - if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { + if (inferenceResults instanceof 
DenseEmbeddingFloatResults textEmbeddingResults) { validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.size()); @@ -139,7 +133,7 @@ private static List translateToChunkedResults( List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.get(i).input().length()) + new ChunkedInference.TextOffset(0, inputs.get(i).getInput().length()) ) ) ) @@ -154,7 +148,7 @@ private static List translateToChunkedResults( } else { String expectedClasses = Strings.format( "One of [%s,%s]", - TextEmbeddingFloatResults.class.getSimpleName(), + DenseEmbeddingFloatResults.class.getSimpleName(), SparseEmbeddingResults.class.getSimpleName() ); throw createInvalidChunkedResultException(expectedClasses, inferenceResults.getWriteableName()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java index baf1e884108fb..126d5b03fcde4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import 
org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -33,7 +33,7 @@ public class HuggingFaceEmbeddingsResponseEntity { * Parse the response from hugging face. The known formats are an array of arrays and object with an {@code embeddings} field containing * an array of arrays. */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -91,13 +91,13 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult * sentence-transformers/all-MiniLM-L6-v2 * sentence-transformers/all-MiniLM-L12-v2 */ - private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException { - List embeddingList = parseList( + private static DenseEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException { + List embeddingList = parseList( parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } /** @@ -136,22 +136,22 @@ private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser) * intfloat/multilingual-e5-small * sentence-transformers/all-mpnet-base-v2 */ - private static TextEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException { positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List 
embeddingList = parseList( parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private HuggingFaceEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 8cdc8cd182425..bfa47a3e1da9c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -370,7 +370,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java index 4fda9d5661a2c..d12f44932ed6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -30,7 +30,7 @@ public class IbmWatsonxEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response"; - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -41,16 +41,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, 
IbmWatsonxEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -59,7 +59,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private IbmWatsonxEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index f6bd954617b76..b8b06ed3f60eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -299,7 +299,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java index 8eee003accba0..f9c80d2593642 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java @@ -15,9 +15,9 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -126,15 +126,15 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r } private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException { - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } - private static TextEmbeddingFloatResults.Embedding 
parseFloatEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -143,19 +143,19 @@ private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XCo // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException { - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject ); - return new TextEmbeddingBitResults(embeddingList); + return new DenseEmbeddingBitResults(embeddingList); } - private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -164,7 +164,7 @@ private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XConte // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingByteResults.Embedding.of(embeddingList); + return DenseEmbeddingByteResults.Embedding.of(embeddingList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index a74f3202e5fb4..11ee5bf99e696 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -208,7 +208,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = llamaModel.accept(actionCreator); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b114aa8081b9c..707f88f3738cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -168,7 +168,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 
ae49f5dcef13b..67513c88c011b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -359,7 +359,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java index b8130545a711d..d5e8fd726cee8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -65,7 +65,7 @@ public class OpenAiEmbeddingsResponseEntity { * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults 
fromResponse(Request request, HttpResult response) throws IOException { try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); } @@ -83,9 +83,9 @@ public record EmbeddingFloatResult(List embeddingResu PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data")); } - public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { - return new TextEmbeddingFloatResults( - embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() + public DenseEmbeddingFloatResults toTextEmbeddingFloatResults() { + return new DenseEmbeddingFloatResults( + embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 676a1edec126b..9d1fd3247ce77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -156,7 +156,8 @@ public void infer( Map taskSettings, InputType inputType, @Nullable TimeValue timeout, - ActionListener listener + ActionListener listener, + @Nullable List imageUrls ) { if (model instanceof SageMakerModel == false) { listener.onFailure(createInvalidModelException(model)); @@ -287,12 +288,13 @@ public void chunkedInfer( query, null, // no return docs while chunking? null, // no topN while chunking? 
- request.batch().inputs().get(), + request.batch().textInputs().get(), false, // we never stream when chunking null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings inputType, timeout, - ActionListener.runAfter(request.listener(), () -> l.onResponse(null)) + ActionListener.runAfter(request.listener(), () -> l.onResponse(null)), + null ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index a5fd194f12109..2a53265b0bd97 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -24,10 +24,10 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import 
org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; @@ -92,7 +92,7 @@ public Stream namedWriteables() { } @Override - public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + public DenseEmbeddingResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { return switch (model.apiServiceSettings().elementType()) { case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); @@ -120,12 +120,14 @@ public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoint * } */ private static class TextEmbeddingBinary { - private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS); + private static final ParseField TEXT_EMBEDDING_BITS = new ParseField( + DenseEmbeddingBitResults.getArrayNameFromTaskName(DenseEmbeddingResults.TEXT_EMBEDDING) + ); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - TextEmbeddingBitResults.class.getSimpleName(), + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DenseEmbeddingBitResults.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> new TextEmbeddingBitResults((List) args[0]) + args -> new DenseEmbeddingBitResults((List) args[0]) ); static { @@ -153,18 +155,18 @@ private static class TextEmbeddingBinary { private static class TextEmbeddingBytes { private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField("text_embedding_bytes"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - TextEmbeddingByteResults.class.getSimpleName(), + private static final 
ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DenseEmbeddingByteResults.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> new TextEmbeddingByteResults((List) args[0]) + args -> new DenseEmbeddingByteResults((List) args[0]) ); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser BYTE_PARSER = + private static final ConstructingObjectParser BYTE_PARSER = new ConstructingObjectParser<>( - TextEmbeddingByteResults.Embedding.class.getSimpleName(), + DenseEmbeddingByteResults.Embedding.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> TextEmbeddingByteResults.Embedding.of((List) args[0]) + args -> DenseEmbeddingByteResults.Embedding.of((List) args[0]) ); static { @@ -199,18 +201,18 @@ private static class TextEmbeddingBytes { private static class TextEmbeddingFloat { private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField("text_embedding"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - TextEmbeddingByteResults.class.getSimpleName(), + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DenseEmbeddingByteResults.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> new TextEmbeddingFloatResults((List) args[0]) + args -> new DenseEmbeddingFloatResults((List) args[0]) ); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser FLOAT_PARSER = + private static final ConstructingObjectParser FLOAT_PARSER = new ConstructingObjectParser<>( - TextEmbeddingFloatResults.Embedding.class.getSimpleName(), + DenseEmbeddingFloatResults.Embedding.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> TextEmbeddingFloatResults.Embedding.of((List) args[0]) + args -> DenseEmbeddingFloatResults.Embedding.of((List) args[0]) ); static { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java index 6fcbd309551e3..67ee6264f944e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java @@ -26,7 +26,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; @@ -117,7 +117,7 @@ public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest req } @Override - public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + public DenseEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java index aee1ed3ec4ebc..e3ee9410aca7f 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java @@ -62,7 +62,8 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A e ) ) - ) + ), + null ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java index fa0e1b3e590a4..6078531eec3f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -17,8 +17,8 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel; public class ElasticsearchInternalServiceModelValidator implements ModelValidator { @@ -54,7 +54,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A } private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { - if (results instanceof TextEmbeddingResults embeddingResults) { + if (results instanceof 
DenseEmbeddingResults embeddingResults) { var serviceSettings = model.getServiceSettings(); var dimensions = serviceSettings.dimensions(); int embeddingSize = getEmbeddingSize(embeddingResults); @@ -79,7 +79,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi throw new ElasticsearchStatusException( "Validation call did not return expected results type." + "Expected a result of type [" - + TextEmbeddingFloatResults.NAME + + DenseEmbeddingFloatResults.NAME + "] got [" + (results == null ? "null" : results.getWriteableName()) + "]", @@ -88,7 +88,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi } } - private int getEmbeddingSize(TextEmbeddingResults embeddingResults) { + private int getEmbeddingSize(DenseEmbeddingResults embeddingResults) { int embeddingSize; try { embeddingSize = embeddingResults.getFirstEmbeddingSize(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index fac9ee5e9c1c1..dda6ca68be613 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -35,7 +35,7 @@ private static ModelValidator buildModelValidatorForTaskType(TaskType taskType, } switch (taskType) { - case TEXT_EMBEDDING -> { + case TEXT_EMBEDDING, IMAGE_EMBEDDING, MULTIMODAL_EMBEDDING -> { return new TextEmbeddingModelValidator( Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index d2cb1925ad7d8..a5f7dc4923f9e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -23,7 +23,11 @@ import java.util.Map; public class SimpleServiceIntegrationValidator implements ServiceIntegrationValidator { - private static final List TEST_INPUT = List.of("how big"); + private static final List TEST_TEXT_INPUT = List.of("how big"); + // The below data URL represents the base64 encoding of a single black pixel + private static final List TEST_URL_INPUT = List.of( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVQImWNgYGAAAAAEAAGjChXjAAAAAElFTkSuQmCC" + ); private static final String QUERY = "test query"; @Override @@ -33,7 +37,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A model.getTaskType().equals(TaskType.RERANK) ? 
QUERY : null, null, null, - TEST_INPUT, + TEST_TEXT_INPUT, false, Map.of(), InputType.INTERNAL_INGEST, @@ -57,7 +61,8 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A e ) ); - }) + }), + TEST_URL_INPUT ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index ce9df7376ebcb..d9c6bec9f46d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -16,8 +16,8 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; public class TextEmbeddingModelValidator implements ModelValidator { @@ -35,7 +35,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A } private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { - if (results instanceof TextEmbeddingResults embeddingResults) { + if (results instanceof DenseEmbeddingResults embeddingResults) { var serviceSettings = model.getServiceSettings(); var dimensions = serviceSettings.dimensions(); int embeddingSize = getEmbeddingSize(embeddingResults); @@ -60,7 +60,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi throw new 
ElasticsearchStatusException( "Validation call did not return expected results type." + "Expected a result of type [" - + TextEmbeddingFloatResults.NAME + + DenseEmbeddingFloatResults.NAME + "] got [" + (results == null ? "null" : results.getWriteableName()) + "]", @@ -69,7 +69,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi } } - private int getEmbeddingSize(TextEmbeddingResults embeddingResults) { + private int getEmbeddingSize(DenseEmbeddingResults embeddingResults) { int embeddingSize; try { embeddingSize = embeddingResults.getFirstEmbeddingSize(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index c69aeec203e4c..ad62ac3a8c0a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -330,7 +330,11 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute( + new EmbeddingsInput(request.batch().textInputs(), inputType, request.batch().imageUrlInputs()), + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java index f9ba5fd58d21a..61436d509e45a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java 
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java @@ -15,9 +15,9 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; @@ -75,9 +75,9 @@ private static void checkByteBounds(Integer value) { } } - public TextEmbeddingByteResults.Embedding toInferenceByteEmbedding() { + public DenseEmbeddingByteResults.Embedding toInferenceByteEmbedding() { embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds); - return TextEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList()); + return DenseEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList()); } } @@ -108,8 +108,8 @@ record EmbeddingFloatResultEntry(Integer index, List embedding) { PARSER.declareFloatArray(constructorArg(), new ParseField("embedding")); } - public TextEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() { - return TextEmbeddingFloatResults.Embedding.of(embedding); + public DenseEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() { + return DenseEmbeddingFloatResults.Embedding.of(embedding); } } @@ -166,22 +166,22 @@ public 
static InferenceServiceResults fromResponse(Request request, HttpResult r if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null); - List embeddingList = embeddingResult.entries.stream() + List embeddingList = embeddingResult.entries.stream() .map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding) .toList(); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } else if (embeddingType == VoyageAIEmbeddingType.INT8) { var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); - List embeddingList = embeddingResult.entries.stream() + List embeddingList = embeddingResult.entries.stream() .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) .toList(); - return new TextEmbeddingByteResults(embeddingList); + return new DenseEmbeddingByteResults(embeddingList); } else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) { var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); - List embeddingList = embeddingResult.entries.stream() + List embeddingList = embeddingResult.entries.stream() .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) .toList(); - return new TextEmbeddingBitResults(embeddingList); + return new DenseEmbeddingBitResults(embeddingList); } else { throw new IllegalArgumentException( "Illegal embedding_type value: " + embeddingType + ". 
Supported types are: " + VALID_EMBEDDING_TYPES_STRING diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index cdde8c64eb537..4a0ed2b28f2d4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -404,7 +404,7 @@ protected void mockService( doAnswer(ans -> { listenerAction.accept(ans.getArgument(9)); return null; - }).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any(), any()); doAnswer(ans -> { listenerAction.accept(ans.getArgument(3)); return null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index a67dc48b47c88..4fb49199b9fc2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -1012,7 +1012,7 @@ private static ShardBulkInferenceActionFilter createFilter( Runnable runnable = () -> { List results = new ArrayList<>(); for (ChunkInferenceInput input : inputs) { - results.add(model.getResults(input.input())); + results.add(model.getResults(input.getInput())); } listener.onResponse(results); }; diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 411d992adfa3d..fc14aa449816a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -9,15 +9,16 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.hamcrest.Matchers; import java.util.ArrayList; @@ -55,81 +56,83 @@ public void testEmptyInput_NoopChunker() { public void testAnyInput_NoopChunker() { var randomInput = randomAlphaOfLengthBetween(100, 1000); - var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput(randomInput)), 10, NoneChunkingSettings.INSTANCE) + var batches = new EmbeddingRequestChunker<>(List.of(new 
ChunkInferenceTextInput(randomInput)), 10, NoneChunkingSettings.INSTANCE) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(randomInput)); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get().get(0), Matchers.is(randomInput)); } public void testWhitespaceInput_SentenceChunker() { var batches = new EmbeddingRequestChunker<>( - List.of(new ChunkInferenceInput(" ")), + List.of(new ChunkInferenceTextInput(" ")), 10, new SentenceBoundaryChunkingSettings(250, 1) ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" ")); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get().get(0), Matchers.is(" ")); } public void testBlankInput_WordChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 100, 100, 10).batchRequestsWithListeners( + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceTextInput("")), 100, 100, 10).batchRequestsWithListeners( testListener() ); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get().get(0), Matchers.is("")); } public void testBlankInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 10, new SentenceBoundaryChunkingSettings(250, 1)) - .batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>( + 
List.of(new ChunkInferenceTextInput("")), + 10, + new SentenceBoundaryChunkingSettings(250, 1) + ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get().get(0), Matchers.is("")); } public void testInputThatDoesNotChunk_WordChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("ABBAABBA")), 100, 100, 10).batchRequestsWithListeners( - testListener() - ); + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceTextInput("ABBAABBA")), 100, 100, 10) + .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get().get(0), Matchers.is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { var batches = new EmbeddingRequestChunker<>( - List.of(new ChunkInferenceInput("ABBAABBA")), + List.of(new ChunkInferenceTextInput("ABBAABBA")), 10, new SentenceBoundaryChunkingSettings(250, 1) ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get().get(0), Matchers.is("ABBAABBA")); } public void testShortInputsAreSingleBatch() { - ChunkInferenceInput input = new ChunkInferenceInput("one chunk"); + ChunkInferenceInput input = new ChunkInferenceTextInput("one 
chunk"); var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), contains(input.input())); + assertThat(batches.get(0).batch().textInputs().get(), contains(input.getInput())); } public void testMultipleShortInputsAreSingleBatch() { List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small") ); var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch(); - assertEquals(batch.inputs().get(), ChunkInferenceInput.inputs(inputs)); + assertEquals(batch.textInputs().get(), ChunkInferenceInput.inputs(inputs)); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i).input())); + assertThat(request.chunkText(), equalTo(inputs.get(i).getInput())); assertEquals(i, request.inputIndex()); assertEquals(0, request.chunkIndex()); } @@ -141,30 +144,30 @@ public void testManyInputsMakeManyBatches() { var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add(new ChunkInferenceInput("input " + i)); + inputs.add(new ChunkInferenceTextInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(4)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); - 
assertThat(batches.get(3).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(1).batch().textInputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(2).batch().textInputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(3).batch().textInputs().get(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get().get(9)); + assertEquals("input 0", batches.get(0).batch().textInputs().get().get(0)); + assertEquals("input 9", batches.get(0).batch().textInputs().get().get(9)); assertThat( - batches.get(1).batch().inputs().get(), + batches.get(1).batch().textInputs().get(), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get().get(9)); - assertThat(batches.get(3).batch().inputs().get(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().textInputs().get().get(0)); + assertEquals("input 29", batches.get(2).batch().textInputs().get().get(9)); + assertThat(batches.get(3).batch().textInputs().get(), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i).input())); + assertThat(request.chunkText(), equalTo(inputs.get(i).getInput())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -176,31 +179,31 @@ public void testChunkingSettingsProvided() { var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add(new ChunkInferenceInput("input " + i)); + inputs.add(new 
ChunkInferenceTextInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(4)); - assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(3).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().textInputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(1).batch().textInputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(2).batch().textInputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(3).batch().textInputs().get(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get().get(9)); + assertEquals("input 0", batches.get(0).batch().textInputs().get().get(0)); + assertEquals("input 9", batches.get(0).batch().textInputs().get().get(9)); assertThat( - batches.get(1).batch().inputs().get(), + batches.get(1).batch().textInputs().get(), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get().get(9)); - assertThat(batches.get(3).batch().inputs().get(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().textInputs().get().get(0)); + assertEquals("input 29", batches.get(2).batch().textInputs().get().get(9)); + assertThat(batches.get(3).batch().textInputs().get(), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { 
EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i).input())); + assertThat(request.chunkText(), equalTo(inputs.get(i).getInput())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -220,10 +223,10 @@ public void testLongInputChunkedOverMultipleBatches() { } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small") ); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); @@ -231,7 +234,7 @@ public void testLongInputChunkedOverMultipleBatches() { assertThat(batches, hasSize(2)); var batch = batches.get(0).batch(); - assertThat(batch.inputs().get(), hasSize(batchSize)); + assertThat(batch.textInputs().get(), hasSize(batchSize)); assertThat(batch.requests(), hasSize(batchSize)); EmbeddingRequestChunker.Request request = batch.requests().get(0); @@ -248,7 +251,7 @@ public void testLongInputChunkedOverMultipleBatches() { } batch = batches.get(1).batch(); - assertThat(batch.inputs().get(), hasSize(4)); + assertThat(batch.textInputs().get(), hasSize(4)); assertThat(batch.requests(), hasSize(4)); for (int requestIndex = 0; requestIndex < 2; requestIndex++) { @@ -281,9 +284,9 @@ public void testVeryLongInput_Sparse() { } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small") ); var finalListener = testListener(); @@ -294,9 +297,9 
@@ public void testVeryLongInput_Sparse() { // there are 10002 inference requests, resulting in 2001 batches. assertThat(batches, hasSize(2001)); for (int i = 0; i < 2000; i++) { - assertThat(batches.get(i).batch().inputs().get(), hasSize(5)); + assertThat(batches.get(i).batch().textInputs().get(), hasSize(5)); } - assertThat(batches.get(2000).batch().inputs().get(), hasSize(2)); + assertThat(batches.get(2000).batch().textInputs().get(), hasSize(2)); // Produce inference results for each request, with just the token // "word" and increasing weights. @@ -318,7 +321,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f))); @@ -334,8 +337,8 @@ public void testVeryLongInput_Sparse() { // The first merged chunk consists of 20 small chunks (so 400 words) and the max // weight is the weight of the 20th small chunk (so 21/16384). 
- assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f))); @@ -343,10 +346,13 @@ public void testVeryLongInput_Sparse() { // The last merged chunk consists of 19 small chunks (so 380 words) and the max // weight is the weight of the 10000th small chunk (so 10001/16384). 
assertThat( - getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ") ); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(511).offset()), + endsWith(" word199998 word199999") + ); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f))); @@ -356,7 +362,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f))); @@ -373,9 +379,9 @@ public void testVeryLongInput_Float() { } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small") ); var finalListener = testListener(); @@ -386,19 +392,19 @@ public void 
testVeryLongInput_Float() { // there are 10002 inference requests, resulting in 2001 batches. assertThat(batches, hasSize(2001)); for (int i = 0; i < 2000; i++) { - assertThat(batches.get(i).batch().inputs().get(), hasSize(5)); + assertThat(batches.get(i).batch().textInputs().get(), hasSize(5)); } - assertThat(batches.get(2000).batch().inputs().get(), hasSize(2)); + assertThat(batches.get(2000).batch().textInputs().get(), hasSize(2)); // Produce inference results for each request, with increasing weights. float weight = 0f; for (var batch : batches) { - var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>(); + var embeddings = new ArrayList<DenseEmbeddingFloatResults.Embedding>(); for (int i = 0; i < batch.batch().requests().size(); i++) { weight += 1 / 16384f; - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { weight })); } - batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + batch.listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } assertNotNull(finalListener.results); @@ -409,9 +415,11 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + DenseEmbeddingFloatResults.Embedding embedding = 
(DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks() + .get(0) + .embedding(); assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for @@ -425,21 +433,24 @@ public void testVeryLongInput_Float() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2/16384 ... 21/16384. - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983/16384 ... 10001/16384. 
assertThat( - getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ") ); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); - assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(511).offset()), + endsWith(" word199998 word199999") + ); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); // The last input has the token with weight 10002/16384. 
@@ -447,9 +458,9 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); } @@ -464,9 +475,9 @@ public void testVeryLongInput_Byte() { } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small") ); var finalListener = testListener(); @@ -477,19 +488,19 @@ public void testVeryLongInput_Byte() { // there are 10002 inference requests, resulting in 2001 batches. assertThat(batches, hasSize(2001)); for (int i = 0; i < 2000; i++) { - assertThat(batches.get(i).batch().inputs().get(), hasSize(5)); + assertThat(batches.get(i).batch().textInputs().get(), hasSize(5)); } - assertThat(batches.get(2000).batch().inputs().get(), hasSize(2)); + assertThat(batches.get(2000).batch().textInputs().get(), hasSize(2)); // Produce inference results for each request, with increasing weights. 
byte weight = 0; for (var batch : batches) { - var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>(); + var embeddings = new ArrayList<DenseEmbeddingByteResults.Embedding>(); for (int i = 0; i < batch.batch().requests().size(); i++) { weight += 1; - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { weight })); } - batch.listener().onResponse(new TextEmbeddingByteResults(embeddings)); + batch.listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } assertNotNull(finalListener.results); @@ -500,9 +511,9 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + DenseEmbeddingByteResults.Embedding embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 1 })); // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for @@ -516,22 +527,25 @@ public void testVeryLongInput_Byte() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12. 
- assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 12 })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so // the average of -1, 0, 1, ... , 17, so 8. 
assertThat( - getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ") ); - assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); - assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedEmbedding.chunks().get(511).offset()), + endsWith(" word199998 word199999") + ); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 8 })); // The last input has the token with weight 10002 % 256 = 18 @@ -539,9 +553,9 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), 
equalTo(new byte[] { 18 })); } @@ -558,10 +572,10 @@ public void testMergingListener_Float() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small") ); var finalListener = testListener(); @@ -570,18 +584,18 @@ public void testMergingListener_Float() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); } - batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); } - batches.get(1).listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + batches.get(1).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } assertNotNull(finalListener.results); @@ -591,7 +605,7 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), 
chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -599,13 +613,28 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); assertThat( - getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(5).offset()), + getMatchedText(inputs.get(1).getInput(), chunkedFloatResult.chunks().get(0).offset()), + startsWith("passage_input0 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedFloatResult.chunks().get(1).offset()), + startsWith(" passage_input20 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedFloatResult.chunks().get(2).offset()), + startsWith(" passage_input40 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedFloatResult.chunks().get(3).offset()), + startsWith(" passage_input60 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedFloatResult.chunks().get(4).offset()), + startsWith(" passage_input80 ") + ); + assertThat( + 
getMatchedText(inputs.get(1).getInput(), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 ") ); } @@ -614,14 +643,14 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).getInput(), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -638,10 +667,10 @@ public void testMergingListener_Byte() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small") ); var finalListener = testListener(); @@ -650,18 +679,18 @@ public void testMergingListener_Byte() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + 
embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(0).listener().onResponse(new TextEmbeddingByteResults(embeddings)); + batches.get(0).listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(1).listener().onResponse(new TextEmbeddingByteResults(embeddings)); + batches.get(1).listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } assertNotNull(finalListener.results); @@ -671,7 +700,7 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -679,26 +708,41 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - 
assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(1).offset()), + startsWith(" passage_input20 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(2).offset()), + startsWith(" passage_input40 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(3).offset()), + startsWith(" passage_input60 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(4).offset()), + startsWith(" passage_input80 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - 
assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).getInput(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -715,10 +759,10 @@ public void testMergingListener_Bit() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput(passageBuilder.toString()), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput(passageBuilder.toString()), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small") ); var finalListener = testListener(); @@ -727,18 +771,18 @@ public void testMergingListener_Bit() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings)); + batches.get(0).listener().onResponse(new DenseEmbeddingBitResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings)); + batches.get(1).listener().onResponse(new DenseEmbeddingBitResults(embeddings)); } assertNotNull(finalListener.results); @@ -748,7 +792,7 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, 
instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -756,26 +800,41 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(1).offset()), + startsWith(" passage_input20 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(2).offset()), + startsWith(" passage_input40 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), 
chunkedByteResult.chunks().get(3).offset()), + startsWith(" passage_input60 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(4).offset()), + startsWith(" passage_input80 ") + ); + assertThat( + getMatchedText(inputs.get(1).getInput(), chunkedByteResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).getInput(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -792,10 +851,10 @@ public void testMergingListener_Sparse() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small"), - new ChunkInferenceInput(passageBuilder.toString()) + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small"), + new ChunkInferenceTextInput(passageBuilder.toString()) ); var finalListener = testListener(); @@ -832,21 +891,21 @@ public void testMergingListener_Sparse() { 
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).getInput(), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); } { var chunkedResult = finalListener.results.get(1); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(1).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(1).getInput(), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(2).getInput(), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); } { // this is the large input split in multiple chunks @@ -854,13 +913,16 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); assertThat( - getMatchedText(inputs.get(3).input(), 
chunkedSparseResult.chunks().get(1).offset()), + getMatchedText(inputs.get(3).getInput(), chunkedSparseResult.chunks().get(0).offset()), + startsWith("passage_input0 ") + ); + assertThat( + getMatchedText(inputs.get(3).getInput(), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 ") ); assertThat( - getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(8).offset()), + getMatchedText(inputs.get(3).getInput(), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 ") ); } @@ -868,9 +930,9 @@ public void testMergingListener_Sparse() { public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of( - new ChunkInferenceInput("1st small"), - new ChunkInferenceInput("2nd small"), - new ChunkInferenceInput("3rd small") + new ChunkInferenceTextInput("1st small"), + new ChunkInferenceTextInput("2nd small"), + new ChunkInferenceTextInput("3rd small") ); var failureMessage = new AtomicReference(); @@ -892,10 +954,10 @@ public void onFailure(Exception e) { var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener); assertThat(batches, hasSize(1)); - var embeddings = new ArrayList(); - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); - batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + var embeddings = new ArrayList(); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get()); } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java index d6ba10b1932dc..344fc5cf8c480 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java @@ -23,7 +23,7 @@ public void testCallingGetInputs_invokesSupplier() { invoked.set(true); return list; }; - EmbeddingsInput input = new EmbeddingsInput(supplier, null); + EmbeddingsInput input = new EmbeddingsInput(supplier, null, List.of()); // Ensure we don't invoke the supplier until we call getInputs() assertThat(invoked.get(), is(false)); @@ -33,7 +33,7 @@ public void testCallingGetInputs_invokesSupplier() { public void testCallingGetInputsTwice_throws() { Supplier> supplier = () -> List.of("input"); - EmbeddingsInput input = new EmbeddingsInput(supplier, null); + EmbeddingsInput input = new EmbeddingsInput(supplier, null, List.of()); input.getInputs(); var exception = expectThrows(AssertionError.class, input::getInputs); assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 838e4576716ff..8d98d6815a0bd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -395,7 +395,10 @@ private SemanticTextIndexOptions getDefaultSparseVectorIndexOptionsForMapper(Map public void 
testInvalidTaskTypes() { for (var taskType : TaskType.values()) { - if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.SPARSE_EMBEDDING) { + if (taskType == TaskType.TEXT_EMBEDDING + || taskType == TaskType.SPARSE_EMBEDDING + || taskType == TaskType.IMAGE_EMBEDDING + || taskType == TaskType.MULTIMODAL_EMBEDDING) { continue; } Exception e = expectThrows( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index d1499f4009d0a..d596677f4125f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -25,10 +25,10 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings; import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; @@ -211,7 +211,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode } chunks.add( new EmbeddingResults.Chunk( - new TextEmbeddingByteResults.Embedding(values), + new 
DenseEmbeddingByteResults.Embedding(values), new ChunkedInference.TextOffset(0, input.length()) ) ); @@ -233,7 +233,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Mod } chunks.add( new EmbeddingResults.Chunk( - new TextEmbeddingFloatResults.Embedding(values), + new DenseEmbeddingFloatResults.Embedding(values), new ChunkedInference.TextOffset(0, input.length()) ) ); @@ -415,8 +415,8 @@ public static ChunkedInference toChunkedResult( ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText); double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType()); EmbeddingResults.Embedding embedding = switch (elementType) { - case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); - case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values)); + case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values)); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java index 17438d9786ba3..367f5f297020a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java @@ -20,8 +20,8 @@ import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -59,7 +59,7 @@ protected void if (inferenceResults instanceof TextExpansionResults textExpansionResults) { inferenceServiceResults = SparseEmbeddingResults.of(List.of(textExpansionResults)); } else if (inferenceResults instanceof MlTextEmbeddingResults mlTextEmbeddingResults) { - inferenceServiceResults = TextEmbeddingFloatResults.of(List.of(mlTextEmbeddingResults)); + inferenceServiceResults = DenseEmbeddingFloatResults.of(List.of(mlTextEmbeddingResults)); } else { throw new IllegalStateException("Unexpected inference results type [" + inferenceResults.getWriteableName() + "]"); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b2d7218720a57..593f10360e1bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -67,8 +67,8 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.XPackClientPlugin; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import 
org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; @@ -350,7 +350,7 @@ private InferenceAction.Response generateTextEmbeddingInferenceResponse() { Arrays.fill(inference, 1.0); MlTextEmbeddingResults textEmbeddingResults = new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false); - return new InferenceAction.Response(TextEmbeddingFloatResults.of(List.of(textEmbeddingResults))); + return new InferenceAction.Response(DenseEmbeddingFloatResults.of(List.of(textEmbeddingResults))); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 4e60b09530684..ece25247bb5b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -23,7 +23,7 @@ import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.Before; @@ -234,7 +234,7 @@ public void testExtractProductUseCase_EmptyWhenHeaderValueEmpty() { static InferenceAction.Response createResponse() { return new InferenceAction.Response( - new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1 }))) + new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1 }))) ); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index ec07d7b547004..68513567d50d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -452,7 +452,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotValid() throws IOException { new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 69e5228a927e7..ae9ebb34c4fb6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -139,7 +139,7 @@ protected void doInfer( PlainActionFuture listener = new PlainActionFuture<>(); - testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener); + testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, listener, null); listener.actionGet(TIMEOUT); assertEquals(configuredTimeout, capturedTimeout.get()); @@ -178,7 +178,19 @@ protected void doInfer( PlainActionFuture listener = new PlainActionFuture<>(); - testService.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, providedTimeout, listener); + 
testService.infer( + model, + null, + null, + null, + List.of("test input"), + false, + Map.of(), + InputType.SEARCH, + providedTimeout, + listener, + null + ); listener.actionGet(TIMEOUT); assertEquals(providedTimeout, capturedTimeout.get()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index cbb119d3e5710..4ea3711ce8f58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -543,7 +543,8 @@ private InferenceEventsAssertion streamCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 7636c16cd2c6f..34cc6639f5d1f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import 
org.elasticsearch.inference.InferenceService; @@ -33,9 +34,9 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -374,7 +375,8 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t new HashMap<>(), InputType.CLASSIFICATION, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -418,7 +420,8 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi new HashMap<>(), InputType.CLASSIFICATION, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -462,7 +465,8 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -492,7 +496,7 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx } private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { - var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")); + List input = List.of(new ChunkInferenceTextInput("foo"), new ChunkInferenceTextInput("bar")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -516,7 +520,7 @@ private void 
testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var firstResult = results.getFirst(); assertThat(firstResult, instanceOf(ChunkedInferenceEmbedding.class)); Class expectedClass = switch (taskType) { - case TEXT_EMBEDDING -> TextEmbeddingFloatResults.Chunk.class; + case TEXT_EMBEDDING -> DenseEmbeddingFloatResults.Chunk.class; case SPARSE_EMBEDDING -> SparseEmbeddingResults.Chunk.class; default -> null; }; @@ -650,10 +654,10 @@ private AlibabaCloudSearchModel createEmbeddingsModel( ) { public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings) { return (inferenceInputs, timeout, listener) -> { - TextEmbeddingFloatResults results = new TextEmbeddingFloatResults( + DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f }) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java index b09fbf43a8ca4..bd2bae59bea79 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java @@ -21,11 +21,11 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; 
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; @@ -88,7 +88,7 @@ public void testExecute_withTextEmbeddingsAction_Success() { float[] values = { 0.1111111f, 0.2222222f, 0.3333333f }; doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - listener.onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(values)))); + listener.onResponse(new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(values)))); return Void.TYPE; }).when(sender).send(any(), any(), any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java index ed8a1185bd846..28e5c9422a970 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import 
org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.AlibabaCloudSearchRequest; @@ -50,14 +50,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException, URI uri = new URI("mock_uri"); when(request.getURI()).thenReturn(uri); - TextEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index d7eb32861da92..a5b44f4160ca5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import 
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -38,7 +38,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -979,7 +979,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -1023,7 +1024,7 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { - var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); + var results = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); requestSender.enqueue(results); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1036,7 +1037,8 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1064,8 +1066,8 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException 
) ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) + var results = new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) ); requestSender.enqueue(results); @@ -1088,7 +1090,8 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException new HashMap<>(), InputType.CLASSIFICATION, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1139,7 +1142,8 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1277,7 +1281,8 @@ public void testInfer_UnauthorizedResponse() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -1340,14 +1345,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { - var mockResults1 = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) + var mockResults1 = new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) ); requestSender.enqueue(mockResults1); } { - var mockResults2 = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F })) + var mockResults2 = new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 
0.223F, 0.278F })) ); requestSender.enqueue(mockResults2); } @@ -1356,7 +1361,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1370,10 +1375,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123F, 0.678F }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1382,10 +1387,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223F, 0.278F }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index 5dd42dc66485f..50ae3b4182f82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -53,8 +53,8 @@ public void shutdown() throws IOException { public void testEmbeddingsRequestAction_Titan() throws IOException { var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); - var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); - var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults); + var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); + var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults); try (var sender = new AmazonBedrockMockRequestSender()) { sender.enqueue(mockedResult); var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); @@ -91,8 +91,8 @@ public void 
testEmbeddingsRequestAction_Titan() throws IOException { public void testEmbeddingsRequestAction_Cohere() throws IOException { var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); - var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); - var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults); + var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); + var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults); try (var sender = new AmazonBedrockMockRequestSender()) { sender.enqueue(mockedResult); var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 7de44857cf58e..8babe432997ca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -466,7 +466,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -523,7 +524,8 @@ public void testInfer_SendsCompletionRequest() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -598,7 +600,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { new HashMap<>(), InputType.INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index f1531929db8c3..6d83dc784d897 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -39,8 +39,8 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1199,7 +1199,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws 
IOExc new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -1239,7 +1240,8 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept new HashMap<>(), InputType.CLASSIFICATION, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -1331,7 +1333,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1345,10 +1347,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1357,10 +1359,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), 
instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 1.0123f, -1.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1403,7 +1405,8 @@ public void testInfer_WithChatCompletionModel() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1440,7 +1443,8 @@ public void testInfer_WithRerankModel() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1497,7 +1501,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -1555,7 +1560,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java index 3da3598a4637a..e2ebd72b17da5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java 
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -50,11 +50,11 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { var entity = new AzureAiStudioEmbeddingsResponseEntity(); - var parsedResults = (TextEmbeddingFloatResults) entity.apply( + var parsedResults = (DenseEmbeddingFloatResults) entity.apply( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F))))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 55b10f2e2b9d7..edb661dc12f27 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import 
org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -37,7 +37,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -765,7 +765,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -823,7 +824,8 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -919,7 +921,8 @@ public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxExcept new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -990,7 +993,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + 
List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1004,10 +1007,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1016,10 +1019,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 1.123f, -1.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1088,7 +1091,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return 
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 67f545a8104de..357e5d99c90e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -39,8 +39,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -793,7 +793,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException new HashMap<>(), InputType.INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -819,9 +820,16 @@ public void testInfer_SendsRequest() throws IOException { String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -863,7 +871,8 @@ public void testInfer_SendsRequest() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -880,7 +889,18 @@ public void testInfer_SendsRequest() throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document", "embedding_types", List.of("float"))) + is( + Map.of( + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("float") + ) + ) ); } } @@ -964,7 +984,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -982,9 +1003,16 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -1026,7 +1054,8 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA new HashMap<>(), InputType.INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1044,7 +1073,18 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document", "embedding_types", List.of("float"))) + is( + Map.of( + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("float") + ) + ) ); } } @@ -1058,9 +1098,16 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -1102,7 +1149,8 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1119,7 +1167,18 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "input_type", "search_document", "embedding_types", List.of("float"))) + is( + Map.of( + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), + "model", + "model", + "input_type", + "search_document", + "embedding_types", + List.of("float") + ) + ) ); } } @@ -1177,7 +1236,8 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec new HashMap<>(), 
InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1207,9 +1267,16 @@ public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecif String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -1252,7 +1319,8 @@ public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecif new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); listener.actionGet(TIMEOUT); @@ -1261,7 +1329,18 @@ public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecif var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) + is( + Map.of( + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), + "model", + "model", + "embedding_types", + List.of("float"), + "input_type", + "search_query" + ) + ) ); } } @@ -1305,9 +1384,16 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -1338,7 +1424,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1354,7 +1440,7 @@ private void 
testChunkedInfer(CohereEmbeddingsModel model) throws IOException { assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1365,7 +1451,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1381,7 +1467,21 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) + is( + Map.of( + "inputs", + List.of( + Map.of("content", List.of(Map.of("text", "a", "type", "text"))), + Map.of("content", List.of(Map.of("text", "bb", "type", "text"))) + ), + "model", + "model", + "embedding_types", + List.of("float"), + "input_type", + "search_query" + ) + ) ); } } @@ -1395,9 +1495,16 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "int8": [ [ @@ -1437,7 +1544,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new 
ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1451,10 +1558,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var byteResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(byteResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset()); - assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); assertArrayEquals( new byte[] { 23, -23 }, - ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() ); } { @@ -1462,10 +1569,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var byteResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(byteResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset()); - assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); assertArrayEquals( new byte[] { 24, -24 }, - ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() ); } @@ -1480,7 +1587,21 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8"), "input_type", "search_query")) + is( + Map.of( + "inputs", + List.of( + 
Map.of("content", List.of(Map.of("text", "a", "type", "text"))), + Map.of("content", List.of(Map.of("text", "bb", "type", "text"))) + ), + "model", + "model", + "embedding_types", + List.of("int8"), + "input_type", + "search_query" + ) + ) ); } } @@ -1522,7 +1643,8 @@ private InferenceEventsAssertion streamChatCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); @@ -1541,42 +1663,44 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception { @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { try (var service = createCohereService()) { - String content = XContentHelper.stripWhitespace(""" - { - "service": "cohere", - "name": "Cohere", - "task_types": ["text_embedding", "rerank", "completion"], - "configurations": { - "api_key": { - "description": "API Key for the provider you're connecting to.", - "label": "API Key", - "required": true, - "sensitive": true, - "updatable": true, - "type": "str", - "supported_task_types": ["text_embedding", "rerank", "completion"] - }, - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": false, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "rerank", "completion"] - }, - "rate_limit.requests_per_minute": { - "description": "Minimize the number of rate limit errors.", - "label": "Rate Limit", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "rerank", "completion"] + String content = XContentHelper.stripWhitespace( + """ + { + "service": "cohere", + "name": "Cohere", + "task_types": ["text_embedding", "rerank", "completion", "image_embedding", "multimodal_embedding"], + "configurations": { + 
"api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "image_embedding", "multimodal_embedding"] + }, + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": false, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "rerank", "completion", "image_embedding", "multimodal_embedding"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "rerank", "completion", "image_embedding", "multimodal_embedding"] + } } } - } - """); + """ + ); InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( new BytesArray(content), XContentType.JSON diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index 8ab76bb728802..e9be6a1faa66b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -81,8 +81,15 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } ], "embeddings": { "float": [ @@ -137,8 +144,8 @@ public void 
testCreate_CohereEmbeddingsModel() throws IOException { requestMap, is( Map.of( - "texts", - List.of("abc"), + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), "model", "model", "input_type", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index ca068d3e1859d..b4012edc6c824 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -89,8 +89,15 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "abc" + } + ] + } ], "embeddings": { "float": [ @@ -148,8 +155,8 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { requestMap, is( Map.of( - "texts", - List.of("abc"), + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), "model", "model", "input_type", @@ -173,8 +180,15 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I String responseJson = """ { "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "abc" + } + ] + } ], "embeddings": { "int8": [ @@ -229,8 +243,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I requestMap, is( Map.of( - "texts", - List.of("abc"), + "inputs", + List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))), "model", "model", "input_type", diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index fd380b8fd973d..fd229a13e2051 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -134,7 +135,8 @@ public static CohereEmbeddingsModel createModel( ), taskSettings, chunkingSettings, - new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())), + TaskType.TEXT_EMBEDDING ); } @@ -177,7 +179,8 @@ public static CohereEmbeddingsModel createModel( ), taskSettings, null, - new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())), + TaskType.TEXT_EMBEDDING ); } @@ -207,7 +210,8 @@ public static CohereEmbeddingsModel createModel( ), taskSettings, null, - new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())), + TaskType.TEXT_EMBEDDING ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java index a7e009d63a903..c8a25d8fa1437 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java @@ -10,6 +10,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; @@ -38,6 +39,7 @@ public void testCreateRequest() throws IOException { var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), + null, inputType, CohereEmbeddingsModelTests.createModel( null, @@ -64,7 +66,7 @@ public void testCreateRequest() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("inputs"), is(List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))))); MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); MatcherAssert.assertThat(requestMap.get("model"), is("model id")); MatcherAssert.assertThat(requestMap.get("truncate"), is("start")); @@ -75,6 +77,7 @@ public void testCreateRequest_WithTaskSettingsInputType() throws IOException { var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), + null, InputType.UNSPECIFIED, CohereEmbeddingsModelTests.createModel( "url", @@ -100,6 +103,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th var inputType = 
InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), + null, inputType, CohereEmbeddingsModelTests.createModel( "http://localhost", @@ -126,7 +130,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("inputs"), is(List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))))); MatcherAssert.assertThat(requestMap.get("model"), is("model")); MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("int8"))); MatcherAssert.assertThat(requestMap.get("truncate"), is("end")); @@ -137,6 +141,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), + null, inputType, CohereEmbeddingsModelTests.createModel( null, @@ -162,7 +167,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("inputs"), is(List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))))); MatcherAssert.assertThat(requestMap.get("model"), is("model")); MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("binary"))); MatcherAssert.assertThat(requestMap.get("truncate"), is("end")); @@ -173,6 +178,7 @@ public void testCreateRequest_TruncateNone() throws IOException { var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), + null, inputType, CohereEmbeddingsModelTests.createModel( null, @@ -199,7 +205,7 @@ public void testCreateRequest_TruncateNone() throws IOException { ); var requestMap = 
entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("inputs"), is(List.of(Map.of("content", List.of(Map.of("text", "abc", "type", "text")))))); MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); MatcherAssert.assertThat(requestMap.get("truncate"), is("none")); validateInputType(requestMap, null, inputType); @@ -209,6 +215,7 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException var entity = createRequest( "cohere model", List.of("abc"), + List.of("def"), InputType.INTERNAL_INGEST, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), CohereEmbeddingType.FLOAT @@ -218,14 +225,40 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"cohere model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + String expectedResult = XContentHelper.stripWhitespace(""" + { + "inputs": [ + { + "content": [ + { + "type": "text", "text": "abc" + } + ] + }, + { + "content": [ + { + "type": "image_url", "image_url": { + "url": "def" + } + } + ] + } + ], + "model":"cohere model", + "input_type":"search_document", + "embedding_types":["float"], + "truncate":"start" + } + """); + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(expectedResult)); } public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { var entity = createRequest( "cohere model", List.of("abc"), + null, InputType.INGEST, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), CohereEmbeddingType.INT8 @@ -235,14 +268,31 @@ public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() 
entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"cohere model","input_type":"search_document","embedding_types":["int8"],"truncate":"none"}""")); + String expectedResult = XContentHelper.stripWhitespace(""" + { + "inputs": [ + { + "content": [ + { + "type": "text", "text": "abc" + } + ] + } + ], + "model":"cohere model", + "input_type":"search_document", + "embedding_types":["int8"], + "truncate":"none" + } + """); + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(expectedResult)); } public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { var entity = createRequest( "cohere model", List.of("abc"), + null, InputType.INTERNAL_SEARCH, new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), CohereEmbeddingType.BYTE @@ -252,14 +302,31 @@ public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() thr entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + String expectedResult = XContentHelper.stripWhitespace(""" + { + "inputs": [ + { + "content": [ + { + "type": "text", "text": "abc" + } + ] + } + ], + "model":"cohere model", + "input_type":"search_query", + "embedding_types":["int8"], + "truncate":"none" + } + """); + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(expectedResult)); } public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { var entity = createRequest( "cohere model", List.of("abc"), + null, InputType.SEARCH, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), CohereEmbeddingType.BINARY @@ -269,14 +336,31 @@ public void 
testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() thr entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + String expectedResult = XContentHelper.stripWhitespace(""" + { + "inputs": [ + { + "content": [ + { + "type": "text", "text": "abc" + } + ] + } + ], + "model":"cohere model", + "input_type":"search_query", + "embedding_types":["binary"], + "truncate":"none" + } + """); + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(expectedResult)); } public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { var entity = createRequest( "cohere model", List.of("abc"), + null, InputType.SEARCH, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), CohereEmbeddingType.BIT @@ -286,23 +370,45 @@ public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + String expectedResult = XContentHelper.stripWhitespace(""" + { + "inputs": [ + { + "content": [ + { + "type": "text", "text": "abc" + } + ] + } + ], + "model":"cohere model", + "input_type":"search_query", + "embedding_types":["binary"], + "truncate":"none" + } + """); + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(expectedResult)); } - public static CohereV2EmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { - return new CohereV2EmbeddingsRequest(input, inputType, model); + public static CohereV2EmbeddingsRequest createRequest( + List input, + List imageUrls, + InputType inputType, + 
CohereEmbeddingsModel model + ) { + return new CohereV2EmbeddingsRequest(input, inputType, model, imageUrls); } public static CohereV2EmbeddingsRequest createRequest( String modelId, List input, + List imageUrls, InputType inputType, CohereEmbeddingsTaskSettings taskSettings, CohereEmbeddingType embeddingType ) { var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, modelId, embeddingType); - return new CohereV2EmbeddingsRequest(input, inputType, model); + return new CohereV2EmbeddingsRequest(input, inputType, model, imageUrls); } private void validateInputType(Map requestMap, InputType taskSettingsInputType, InputType requestInputType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java index 6df356bfe0a80..f0399602da613 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java @@ -9,10 +9,11 @@ import org.apache.http.HttpResponse; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; 
import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.hamcrest.MatcherAssert; @@ -30,9 +31,16 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": [ [ -0.0018434525, @@ -53,13 +61,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { InferenceServiceResults parsedResults = CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); - MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + MatcherAssert.assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); MatcherAssert.assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) ); } @@ -67,9 +76,16 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -90,14 +106,15 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + 
DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) ); } @@ -105,9 +122,16 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -134,14 +158,15 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) ); } @@ -149,9 +174,16 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + 
"content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "invalid_type": [ [ @@ -178,14 +210,15 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir } """; - TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) ); } @@ -193,9 +226,16 @@ public void testFromResponse_ParsesBytes() throws IOException { String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "int8": [ [ @@ -216,14 +256,15 @@ public void testFromResponse_ParsesBytes() throws IOException { } """; - TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 
-1, (byte) 0 }))) ); } @@ -231,9 +272,16 @@ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOEx String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "binary": [ [ @@ -257,14 +305,15 @@ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOEx } """; - TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -272,9 +321,16 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": [ [ -0.0018434525, @@ -297,17 +353,18 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), 
responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), - new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) ) ) ); @@ -317,9 +374,16 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "float": [ [ @@ -344,17 +408,18 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), - new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) ) ) ); @@ -397,17 +462,18 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary( } """; - 
TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( parsedResults.embeddings(), is( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 }) ) ) ); @@ -417,9 +483,16 @@ public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings_not_here": [ [ -0.0018434525, @@ -442,7 +515,8 @@ public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { IllegalStateException.class, () -> CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ) ); @@ -456,9 +530,16 @@ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Neg String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], 
"embeddings": { "int8": [ [ @@ -483,7 +564,8 @@ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Neg IllegalArgumentException.class, () -> CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ) ); @@ -494,9 +576,16 @@ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Pos String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "int8": [ [ @@ -521,7 +610,8 @@ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Pos IllegalArgumentException.class, () -> CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ) ); @@ -532,9 +622,16 @@ public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_N String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "binary": [ [ @@ -559,7 +656,8 @@ public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_N IllegalArgumentException.class, () -> CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ) ); @@ -570,9 +668,16 @@ public void 
testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_P String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "binary": [ [ @@ -597,7 +702,8 @@ public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_P IllegalArgumentException.class, () -> CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ) ); @@ -608,9 +714,16 @@ public void testFromResponse_FailsToFindAValidEmbeddingType() { String responseJson = """ { "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", - "texts": [ - "hello" - ], + "inputs": [ + { + "content": [ + { + "type": "text", + "text": "hello" + } + ] + } + ], "embeddings": { "invalid_type": [ [ @@ -635,7 +748,8 @@ public void testFromResponse_FailsToFindAValidEmbeddingType() { IllegalStateException.class, () -> CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)), + TaskType.TEXT_EMBEDDING ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 55bb98705a2a3..f074530944dd6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -13,7 +13,7 @@ import 
org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -29,9 +29,9 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -335,7 +335,8 @@ public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOEx new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -393,16 +394,17 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); InferenceServiceResults results = listener.actionGet(TIMEOUT); - assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(results, instanceOf(DenseEmbeddingFloatResults.class)); - var embeddingResults = 
(TextEmbeddingFloatResults) results; + var embeddingResults = (DenseEmbeddingFloatResults) results; assertThat( embeddingResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) ); } } @@ -465,7 +467,8 @@ public void testInfer_HandlesRerankRequest_Cohere_Format() throws IOException { new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); InferenceServiceResults results = listener.actionGet(TIMEOUT); @@ -536,7 +539,8 @@ public void testInfer_HandlesCompletionRequest_OpenAI_Format() throws IOExceptio new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); InferenceServiceResults results = listener.actionGet(TIMEOUT); @@ -601,7 +605,8 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); InferenceServiceResults results = listener.actionGet(TIMEOUT); @@ -703,7 +708,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -717,10 +722,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + 
assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -729,10 +734,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -778,7 +783,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a")), + List.of(new ChunkInferenceTextInput("a")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -792,10 +797,10 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( 
new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index e53add6733aca..6f8e1b043f035 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -13,9 +13,9 @@ import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -73,10 +73,10 @@ public void testFromTextEmbeddingResponse() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(results, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) results).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 
-0.02868066355586052f, 0.022033605724573135f }))) + ((DenseEmbeddingFloatResults) results).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java index 7796b5e1e7f6b..2ec0316b814a8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java @@ -16,9 +16,9 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; @@ -121,13 +121,17 @@ public void testParse() throws IOException { """; var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( + 
DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults, - is(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))) + is( + new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })) + ) + ) ); } @@ -154,11 +158,14 @@ public void testParseByte() throws IOException { """; var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE); - TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) parser.parse( + DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, is(new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + assertThat( + parsedResults, + is(new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))) + ); } public void testParseBit() throws IOException { @@ -184,11 +191,11 @@ public void testParseBit() throws IOException { """; var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT); - TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) parser.parse( + DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, is(new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + assertThat(parsedResults, is(new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); } public void testParse_MultipleEmbeddings() 
throws IOException { @@ -222,17 +229,17 @@ public void testParse_MultipleEmbeddings() throws IOException { """; var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults, is( - new TextEmbeddingFloatResults( + new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 1F, -2F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 1F, -2F }) ) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 498fbca10b5a8..7ccab221f7a9e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -231,7 +231,7 @@ public void testDoInfer() throws Exception { try (var service = createService()) { var model = createModel(service, TaskType.COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener, null); var result = listener.actionGet(TIMEOUT); assertThat(result, isA(ChatCompletionResults.class)); var 
completionResults = (ChatCompletionResults) result; @@ -254,7 +254,7 @@ public void testDoInferStream() throws Exception { try (var service = createService()) { var model = createModel(service, TaskType.COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener, null); InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent(""" {"completion":[{"delta":"hello, world"}]}"""); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index d861d5b2bb47b..333ec066fe220 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -40,8 +40,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import 
org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; @@ -425,7 +425,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -465,7 +466,8 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc new HashMap<>(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); @@ -496,7 +498,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -549,7 +552,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { new HashMap<>(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -605,7 +609,8 @@ public void testRerank_SendsRerankRequest() throws IOException { new HashMap<>(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -678,7 +683,8 @@ public void testInfer_PropagatesProductUseCaseHeader() throws IOException { new HashMap<>(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = 
listener.actionGet(TIMEOUT); @@ -811,7 +817,7 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), + List.of(new ChunkInferenceTextInput("hello world"), new ChunkInferenceTextInput("dense embedding")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -876,7 +882,7 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), + List.of(new ChunkInferenceTextInput("hello world"), new ChunkInferenceTextInput("dense embedding")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -892,9 +898,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio var denseResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(denseResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset()); - assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(denseResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f); } @@ -904,9 +910,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio var denseResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(denseResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), 
denseResult.chunks().getFirst().offset()); - assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 7ee6d817f899c..3595075a5a32b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -20,9 +20,9 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -298,8 +298,8 @@ public void 
testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); - var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; assertThat(textEmbeddingResults.embeddings(), hasSize(2)); var firstEmbedding = textEmbeddingResults.embeddings().get(0); @@ -354,8 +354,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); - var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; assertThat(textEmbeddingResults.embeddings(), hasSize(1)); var embedding = textEmbeddingResults.embeddings().get(0); @@ -447,8 +447,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); - var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; assertThat(textEmbeddingResults.embeddings(), hasSize(0)); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java index 2883a1ab73c21..79721b95af067 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; @@ -35,7 +35,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throw } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -64,7 +64,7 @@ public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() th } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -85,7 +85,7 @@ public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception { } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults 
parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -111,7 +111,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta() } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3af19bf46c62e..579390d64c0bd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -53,8 +53,8 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import 
org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; @@ -1018,20 +1018,20 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); assertThat(result1.chunks(), hasSize(1)); - assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(result1.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), - ((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(), 0.0001f ); assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); assertThat(result2.chunks(), hasSize(1)); - assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(result2.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), - ((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) 
result2.chunks().get(0).embedding()).values(), 0.0001f ); assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); @@ -1045,7 +1045,7 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1117,7 +1117,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1189,7 +1189,7 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1235,7 +1235,7 @@ public void testChunkInferSetsTokenization() { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), + List.of(new ChunkInferenceTextInput("foo"), new ChunkInferenceTextInput("bar")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1247,7 +1247,7 @@ public void testChunkInferSetsTokenization() { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), + List.of(new ChunkInferenceTextInput("foo"), new ChunkInferenceTextInput("bar")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1299,7 +1299,7 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { service.chunkedInfer( 
model, null, - List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"), new ChunkInferenceInput("baz")), + List.of(new ChunkInferenceTextInput("foo"), new ChunkInferenceTextInput("bar"), new ChunkInferenceTextInput("baz")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1325,7 +1325,7 @@ public void testChunkingLargeDocument() throws InterruptedException { // build a doc with enough words to make numChunks of chunks int wordsPerChunk = 10; int numWords = numChunks * wordsPerChunk; - var input = new ChunkInferenceInput("word ".repeat(numWords), null); + var input = new ChunkInferenceTextInput("word ".repeat(numWords), null); Client client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); @@ -2034,7 +2034,7 @@ public void test_nullTimeoutUsesClusterSetting() throws InterruptedException { var latch = new CountDownLatch(1); var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); - service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, latchedListener); + service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, null, latchedListener, null); assertTrue(latch.await(30, TimeUnit.SECONDS)); @@ -2078,7 +2078,19 @@ public void test_providedTimeoutPropagateProperly() throws InterruptedException var latch = new CountDownLatch(1); var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); - service.infer(model, null, null, null, List.of("test input"), false, Map.of(), InputType.SEARCH, providedTimeout, latchedListener); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + Map.of(), + InputType.SEARCH, + providedTimeout, + latchedListener, + null + ); assertTrue(latch.await(30, TimeUnit.SECONDS)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 65ccc52cefc30..0212d3bc8d00c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -39,7 +40,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -671,7 +672,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -711,7 +713,8 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); 
assertThat( @@ -788,7 +791,8 @@ public void testInfer_SendsCompletionRequest() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -847,7 +851,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -894,7 +899,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { - var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); + List input = List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -939,12 +944,12 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).getInput().length()), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } 
@@ -954,12 +959,12 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).getInput().length()), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0456f, -0.0456f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -978,7 +983,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(0).input()))), + Map.of("parts", List.of(Map.of("text", input.get(0).getInput()))), "taskType", "RETRIEVAL_DOCUMENT" ), @@ -986,7 +991,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(1).input()))), + Map.of("parts", List.of(Map.of("text", input.get(1).getInput()))), "taskType", "RETRIEVAL_DOCUMENT" ) @@ -1022,7 +1027,8 @@ public void testInfer_ResourceNotFound() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java index eca4a369c29c8..6bfb33769602f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -64,7 +64,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = 
GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -73,8 +73,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), - TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), + DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java index 8f19edb3031d7..c3a8a8b74078b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -42,12 +42,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = 
GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -82,7 +82,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)), - TextEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F)) + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)), + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 7acbe17340757..f6013277688b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ 
-73,7 +73,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index f9a67e7216e75..3eda602d578e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; @@ -97,7 +97,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("abc")), + List.of(new ChunkInferenceTextInput("abc")), new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index b28a6e6636c70..1315aaa05d90a 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -42,8 +42,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -272,7 +272,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -563,7 +564,8 @@ private InferenceEventsAssertion streamCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return 
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); @@ -1037,7 +1039,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1076,7 +1079,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -1112,7 +1116,8 @@ public void testInfer_SendsElserRequest() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1200,7 +1205,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("abc")), + List.of(new ChunkInferenceTextInput("abc")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1212,10 +1217,10 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th var embeddingResult = (ChunkedInferenceEmbedding) result; assertThat(embeddingResult.chunks(), hasSize(1)); assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length()))); - assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { -0.0123f, 0.0123f }, - ((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(), 0.001f ); assertThat(webServer.requests(), hasSize(1)); @@ 
-1252,7 +1257,7 @@ public void testChunkedInfer() throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("abc")), + List.of(new ChunkInferenceTextInput("abc")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1266,10 +1271,10 @@ public void testChunkedInfer() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java index 61e035326d163..e157038f2f244 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java @@ -10,7 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import 
org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -32,14 +32,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -55,14 +55,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -80,7 +80,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -89,8 +89,8 @@ public void 
testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -112,7 +112,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -121,8 +121,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -255,12 +255,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); } public 
void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException { @@ -274,12 +274,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException { @@ -291,12 +291,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException { @@ -310,12 +310,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), 
responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 7bd68f0ba0510..3b5aa8246b768 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -38,7 +39,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -610,7 +611,8 @@ public void 
testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -651,7 +653,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); MatcherAssert.assertThat( @@ -709,7 +712,8 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -732,7 +736,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { } private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException { - var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); + List input = List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -787,12 +791,12 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).getInput().length()), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), 
Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -802,12 +806,12 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).getInput().length()), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0456f, -0.0456f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -854,7 +858,8 @@ public void testInfer_ResourceNotFound() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java index db42e9c49e1e5..c3bdb8c686985 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -66,7 +66,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -75,8 +75,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - 
TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), - TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), + DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index fc50acdbd39b6..f67bba48c0d3c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -38,7 +38,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -792,7 +792,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws 
IOException new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -878,7 +879,8 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -912,7 +914,8 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -970,7 +973,8 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1041,7 +1045,8 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { new HashMap<>(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1096,7 +1101,8 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { new HashMap<>(), InputType.CLUSTERING, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1167,7 +1173,8 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1226,7 +1233,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var 
result = listener.actionGet(TIMEOUT); @@ -1311,7 +1319,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1408,7 +1417,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1491,7 +1501,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1587,7 +1598,8 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1677,7 +1689,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1691,10 +1703,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - 
((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1703,10 +1715,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java index c1b19cb450789..4df61cdc439af 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java @@ -11,9 +11,9 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; @@ -69,10 +69,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -123,13 +123,13 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new 
DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -363,8 +363,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep ); assertThat( - ((TextEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + ((DenseEmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -411,8 +411,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio ); assertThat( - ((TextEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + ((DenseEmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -504,7 +504,7 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( JinaAIEmbeddingsRequestTests.createRequest( List.of("abc"), InputTypeTests.randomWithNull(), @@ -525,9 +525,9 @@ public void testFieldsInDifferentOrderServer() throws IOException { parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new 
DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 243235211a7de..0c406ad35054e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -42,7 +42,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -677,7 +677,7 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + List.of(new ChunkInferenceTextInput("abc"), new ChunkInferenceTextInput("def")), new HashMap<>(), InputType.INTERNAL_INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, @@ -691,11 +691,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.010060793f, -0.0017529363f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -703,11 +703,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.110060793f, -0.1017529363f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -804,7 +804,8 @@ private InferenceEventsAssertion streamCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 50731811e4164..ff8037206e754 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -41,7 +41,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -262,7 +262,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -435,7 +436,8 @@ private InferenceEventsAssertion streamCompletion() throws Exception { new HashMap<>(), InputType.INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); @@ -1003,7 +1005,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -1044,7 +1047,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -1123,7 +1127,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + List.of(new ChunkInferenceTextInput("abc"), new ChunkInferenceTextInput("def")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1137,11 +1141,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -1149,11 +1153,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { 
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -1202,7 +1206,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 676dca2778141..28dc3b4f6ff9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -21,7 +21,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -41,7 +41,7 @@ import 
org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -876,7 +876,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -916,7 +917,8 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); assertThat( @@ -953,7 +955,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -994,7 +997,8 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -1056,7 +1060,8 @@ public void testInfer_SendsRequest() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = 
listener.actionGet(TIMEOUT); @@ -1297,7 +1302,8 @@ private InferenceEventsAssertion streamCompletion() throws Exception { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); @@ -1443,7 +1449,8 @@ public void testInfer_UnauthorisedResponse() throws IOException { new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -1539,7 +1546,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1553,11 +1560,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -1566,11 +1573,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); 
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java index f2a430eefc801..77e1f384509b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java @@ -10,7 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -45,14 +45,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new 
HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -86,7 +86,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -95,8 +95,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -254,12 +254,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { @@ -283,12 +283,12 @@ public void 
testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { @@ -365,7 +365,7 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); @@ -374,9 +374,9 @@ public void testFieldsInDifferentOrderServer() throws IOException { parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index 5d6bec1bcfbff..12201a51dc00e 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -50,6 +50,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import java.util.stream.Stream; import static org.elasticsearch.action.ActionListener.assertOnce; @@ -138,7 +139,8 @@ public void testInferWithWrongModel() { null, INPUT_TYPE, THIRTY_SECONDS, - assertUnsupportedModel() + assertUnsupportedModel(), + null ); } @@ -186,7 +188,8 @@ public void testInfer() { verify(schemas, only()).schemaFor(eq(model)); verify(schema, times(1)).request(eq(model), assertRequest()); verify(schema, times(1)).response(eq(model), any(), any()); - }) + }), + null ); verify(client, only()).invoke(any(), any(), any(), any()); verifyNoMoreInteractions(client, schemas, schema); @@ -213,7 +216,7 @@ public void test_nullTimeoutUsesClusterSetting() throws InterruptedException { var latch = new CountDownLatch(1); var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); - service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, null, latchedListener); + service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, null, latchedListener, null); assertTrue(latch.await(30, TimeUnit.SECONDS)); 
assertEquals(configuredTimeout, capturedTimeout.get()); @@ -240,7 +243,7 @@ public void test_providedTimeoutPropagateProperly() throws InterruptedException var latch = new CountDownLatch(1); var latchedListener = new LatchedActionListener<>(ActionListener.noop(), latch); - service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, providedTimeout, latchedListener); + service.infer(model, QUERY, null, null, INPUT, false, null, InputType.SEARCH, providedTimeout, latchedListener, null); assertTrue(latch.await(30, TimeUnit.SECONDS)); assertEquals(providedTimeout, capturedTimeout.get()); @@ -284,7 +287,7 @@ public void testInferStream() { verify(schemas, only()).streamSchemaFor(eq(model)); verify(schema, times(1)).streamRequest(eq(model), assertRequest()); verify(schema, times(1)).streamResponse(eq(model), any()); - })); + }), null); verify(client, only()).invokeStream(any(), any(), any(), any()); verifyNoMoreInteractions(client, schemas, schema); } @@ -321,7 +324,8 @@ public void testInferError() { verify(schemas, only()).schemaFor(eq(model)); verify(schema, times(1)).request(eq(model), assertRequest()); verify(schema, times(1)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException)))); - }) + }), + null ); verify(client, only()).invoke(any(), any(), any(), any()); verifyNoMoreInteractions(client, schemas, schema); @@ -349,7 +353,7 @@ public void testInferException() { assertThat(e, isA(ElasticsearchStatusException.class)); assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id].")); - })); + }), null); verify(client, only()).invokeStream(any(), any(), any(), any()); verifyNoMoreInteractions(client, schemas, schema); } @@ -437,7 +441,7 @@ public void testChunkedInferWithWrongModel() { sageMakerService.chunkedInfer( mockUnsupportedModel(), QUERY, - INPUT.stream().map(ChunkInferenceInput::new).toList(), + 
INPUT.stream().map(ChunkInferenceTextInput::new).collect(Collectors.toList()), null, INPUT_TYPE, THIRTY_SECONDS, @@ -458,7 +462,7 @@ public void testChunkedInfer() throws Exception { sageMakerService.chunkedInfer( model, QUERY, - expectedInput.stream().map(ChunkInferenceInput::new).toList(), + expectedInput.stream().map(ChunkInferenceTextInput::new).collect(Collectors.toList()), null, INPUT_TYPE, THIRTY_SECONDS, @@ -508,7 +512,7 @@ public void testChunkedInferError() { sageMakerService.chunkedInfer( model, QUERY, - expectedInput.stream().map(ChunkInferenceInput::new).toList(), + expectedInput.stream().map(ChunkInferenceTextInput::new).collect(Collectors.toList()), null, INPUT_TYPE, THIRTY_SECONDS, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java index ed0ee43266ab5..3ac0f20d4242c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java @@ -9,8 +9,8 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; @@ -66,8 
+66,8 @@ public void testBitResponse() throws Exception { assertThat(bitResults.embeddings().size(), is(1)); var embedding = bitResults.embeddings().get(0); - assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class)); - assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); + assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class)); + assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); } public void testByteResponse() throws Exception { @@ -87,8 +87,8 @@ public void testByteResponse() throws Exception { assertThat(byteResults.embeddings().size(), is(1)); var embedding = byteResults.embeddings().get(0); - assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class)); - assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); + assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class)); + assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); } public void testFloatResponse() throws Exception { @@ -108,7 +108,7 @@ public void testFloatResponse() throws Exception { assertThat(byteResults.embeddings().size(), is(1)); var embedding = byteResults.embeddings().get(0); - assertThat(embedding, isA(TextEmbeddingFloatResults.Embedding.class)); - assertThat(((TextEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F })); + assertThat(embedding, isA(DenseEmbeddingFloatResults.Embedding.class)); + assertThat(((DenseEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F })); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java index 35b78b004618c..d498542eba5d4 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase; @@ -149,7 +149,7 @@ public void testResponse() throws Exception { assertThat( textEmbeddingFloatResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java index 30e7a33757c16..c9db9907792fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; -import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; @@ -118,7 +118,7 @@ public void testValidate_ElandTextEmbeddingModelValidationFails() { public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserValid() { var dimensions = randomIntBetween(1, 10); - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); var mockUpdatedModel = mock(CustomElandEmbeddingModel.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions); CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions); @@ -151,7 +151,7 @@ public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsS public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserInvalid() { var dimensions = randomIntBetween(1, 10); - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn( randomValueOtherThan(dimensions, () -> randomIntBetween(1, 10)) ); @@ -207,7 +207,7 @@ public void testValidate_ElandTextEmbeddingAndValidationReturnsInvalidResultsTyp public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() { var dimensions = randomIntBetween(1, 10); - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); 
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions); CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null); @@ -239,7 +239,7 @@ public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() { } public void testValidate_ElandTextEmbeddingModelAndEmbeddingSizeRetrievalThrowsException() { - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenThrow(ElasticsearchStatusException.class); CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index b4d563b565eee..1353988fa56c6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -77,6 +77,7 @@ public void testCustomServiceValidator() { eq(Map.of()), eq(InputType.INTERNAL_INGEST), any(), + any(), any() ); verifyNoMoreInteractions(mockService); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java index 6faa4bd07b6aa..a940bf87d1ff0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -71,6 +71,7 @@ public void testValidate_ServiceThrowsException() { eq(Map.of()), eq(InputType.INTERNAL_INGEST), eq(TIMEOUT), + any(), any() ); @@ -114,6 +115,7 @@ private void mockSuccessfulCallToService(String query, InferenceServiceResults r eq(Map.of()), eq(InputType.INTERNAL_INGEST), eq(TIMEOUT), + any(), any() ); @@ -132,6 +134,7 @@ private void verifyCallToService(boolean withQuery) { eq(Map.of()), eq(InputType.INTERNAL_INGEST), eq(TIMEOUT), + any(), any() ); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index 45726f0789667..192a47233befb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -16,10 +16,10 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import 
org.elasticsearch.xpack.inference.EmptyTaskSettingsTests; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.junit.Before; @@ -94,7 +94,7 @@ public void testValidate_ServiceReturnsNonTextEmbeddingResults() { } public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() { - TextEmbeddingFloatResults results = new TextEmbeddingFloatResults(List.of()); + DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults(List.of()); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true); when(mockServiceSettings.dimensions()).thenReturn(randomNonNegativeInt()); @@ -107,7 +107,7 @@ public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() { } public void testValidate_DimensionsSetByUserDoNotEqualEmbeddingSize() { - TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults(); + DenseEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults(); var dimensions = randomValueOtherThan(results.getFirstEmbeddingSize(), ESTestCase::randomNonNegativeInt); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true); @@ -131,7 +131,7 @@ public void testValidate_DimensionsNotSetByUser() { } private void mockSuccessfulValidation(Boolean dimensionsSetByUser) { - TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults(); + DenseEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults(); when(mockModel.getConfigurations()).thenReturn(ModelConfigurationsTests.createRandomInstance()); when(mockModel.getTaskSettings()).thenReturn(EmptyTaskSettingsTests.createRandom()); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(dimensionsSetByUser); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 
8cad8cbad208a..492272998031f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkInferenceTextInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -37,7 +37,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -732,7 +732,8 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -780,7 +781,8 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept new HashMap<>(), InputType.CLUSTERING, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ) ); MatcherAssert.assertThat( @@ -862,7 +864,8 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException 
{ new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -896,7 +899,8 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -952,7 +956,8 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1034,7 +1039,8 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { new HashMap<>(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1116,7 +1122,8 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1179,7 +1186,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1267,7 +1275,8 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1361,7 +1370,8 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1439,7 +1449,8 @@ public void 
testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1535,7 +1546,8 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - listener + listener, + null ); var result = listener.actionGet(TIMEOUT); @@ -1638,7 +1650,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + List.of(new ChunkInferenceTextInput("a"), new ChunkInferenceTextInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1652,10 +1664,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo var floatResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset()); - assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1664,10 +1676,13 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset()); - 
assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat( + floatResult.chunks().getFirst().embedding(), + CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class) + ); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java index 80a00737ddf52..7fa6b0c7bece3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIEmbeddingsRequest; @@ -60,8 +60,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + 
((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -106,11 +106,11 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -299,8 +299,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))) ); } @@ -336,8 +336,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))) ); } @@ -427,15 +427,15 @@ public void testFieldsInDifferentOrderServer() throws IOException { new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), + ((DenseEmbeddingFloatResults) 
parsedResults).embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) );