[ML] Add Embedding inference task type #138198

Merged
DonalEvans merged 20 commits into elastic:main from DonalEvans:embedding-task-type
Dec 3, 2025

@DonalEvans DonalEvans commented Nov 17, 2025

This commit introduces a new TaskType for the inference plugin named "embedding" which allows both images and text to be used as inputs to create dense vectors. This task invokes the new InferenceService.embeddingInfer() method which is currently implemented to throw an UnsupportedOperationException if called on non-test implementations of InferenceService.

The input for this task is specified using a list of "content" objects, each of which specifies the type of input ("text" or "image"), the format of the input ("text" or "base64") and the String value of the input. The "format" field is optional and defaults to "text" for the "text" input type and "base64" for the "image" input type:

"input": [
  {
    "content": {"type": "image", "format": "base64", "value": "image data"}
  },
  {
    "content": [
      {"type": "text", "value": "text input"},
      {"type": "image", "value": "image data"}
    ]
  }
]

It is also possible to specify a single content object rather than a list:

"input": {
  "content": {"type": "text", "format": "text", "value": "text input"}
}

Each content object in the request will result in a single embedding array in the response, meaning that multiple texts or images can be grouped into a single content object and converted into a single embedding array if the model being used supports that.

To preserve input compatibility with the existing text_embedding task, the input can also be specified as a single String or a list of Strings, each of which will be internally parsed as a content object with type "text" and format "text":

"input": "single text input"

OR

"input": ["first text input", "second text input"]
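
The accepted input shapes can be illustrated with a minimal sketch. It assumes the request body has already been parsed into plain Java collections; the `Content` and `ContentGroup` records and the `normalize` helper are hypothetical stand-ins for illustration and are not the classes or parser added by this PR.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class EmbeddingInputSketch {
    // Hypothetical stand-ins for a single input and a group of inputs that yields one embedding.
    record Content(String type, String format, String value) {}
    record ContentGroup(List<Content> contents) {}

    // "format" is optional: it defaults to "text" for text inputs and "base64" for image inputs.
    static String defaultFormat(String type) {
        return switch (type) {
            case "text" -> "text";
            case "image" -> "base64";
            default -> throw new IllegalArgumentException("Unrecognized type [" + type + "]");
        };
    }

    static Content toContent(Map<?, ?> raw) {
        String type = (String) raw.get("type");
        String format = raw.containsKey("format") ? (String) raw.get("format") : defaultFormat(type);
        return new Content(type, format, (String) raw.get("value"));
    }

    // One element of "input" becomes one group, and therefore one embedding in the response.
    static ContentGroup toGroup(Object element) {
        if (element instanceof String text) {
            // Bare strings keep compatibility with the text_embedding input format.
            return new ContentGroup(List.of(new Content("text", "text", text)));
        }
        if (element instanceof Map<?, ?> wrapper) {
            Object content = wrapper.get("content");
            if (content instanceof Map<?, ?> single) {
                return new ContentGroup(List.of(toContent(single)));
            }
            if (content instanceof List<?> many) {
                List<Content> contents = new ArrayList<>();
                for (Object item : many) {
                    contents.add(toContent((Map<?, ?>) item));
                }
                return new ContentGroup(contents);
            }
        }
        throw new IllegalArgumentException("Unsupported input element: " + element);
    }

    // Accepts a single string, a single content object, or a list of either.
    static List<ContentGroup> normalize(Object input) {
        if (input instanceof List<?> list) {
            List<ContentGroup> groups = new ArrayList<>();
            for (Object element : list) {
                groups.add(toGroup(element));
            }
            return groups;
        }
        return List.of(toGroup(input));
    }
}
```

In this sketch, `normalize(List.of("first text input", "second text input"))` produces two groups of one text input each, while a content list containing a text and an image produces a single group with two entries.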

The output format for this task uses the GenericDenseEmbedding*Results classes which produce a format like the one below (with additional "_bytes" or "_bits" suffixes on the "embeddings" array name if those element types are specified).

{
  embeddings=[
    {
      embedding=[54.0, ... 51.0]
    },
    {
      embedding=[45.0, ... 56.0]
    }
  ]
}

An InferenceStringGroup class is introduced to represent content objects with more than one text or image input, which is necessary to support the use case of multiple inputs generating a single embedding. This class is redundant right now, because all code paths assume a 1:1 mapping between number of inputs and number of embeddings generated, but the implementation is designed to support this use case being added in future.

This commit also adds testing for the new task type and updates existing tests and code to properly handle the new task type.

This commit introduces a new TaskType for the inference plugin named
"embedding" which allows both images and text to be used as inputs to
create dense vectors. This task invokes the new
InferenceService.embeddingInfer() method which is currently implemented
to throw an UnsupportedOperationException if called on non-test
implementations of InferenceService.

The input for this task is specified using a list
of "content" objects, each of which specifies the type of input ("text"
or "image_base64") and the String value of the input:

"input": [
  {
    "content": {"type": "image_base64", "value": "image data"}
  },
  {
    "content": [
      {"type": "text", "value": "text input"},
      {"type": "image_base64", "value": "image data"}
    ]
  }
]

It is also possible to specify a single content object rather than a
list:

"input": {
  "content": {"type": "text", "value": "text input"}
}

Each content object in the request will result in a single embedding
array in the response, meaning that multiple texts or images can be
grouped into a single content object and converted into a single
embedding array if the model being used supports that.

To preserve input compatibility with the existing text_embedding task,
the input can also be specified as a single String or a list of Strings,
each of which will be internally parsed as a content object with type
"text":

"input": "single text input"

OR

"input": ["first text input", "second text input"]

The output format for this task uses the GenericDenseEmbedding*Results
classes which produce a format like the one below (with additional
"_bytes" or "_bits" suffixes on the "embeddings" array name if those
element types are specified).

{
  embeddings=[
    {
      embedding=[54.0, ... 51.0]
    },
    {
      embedding=[45.0, ... 56.0]
    }
  ]
}

An InferenceStringGroup class is introduced to represent content objects
with more than one text or image input, which is necessary to support
the use case of multiple inputs generating a single embedding. This
class is redundant right now, because all code paths assume a 1:1
mapping between number of inputs and number of embeddings generated, but
the implementation is designed to support this use case being added in
future.

This commit also adds testing for the new task type and updates existing
tests and code to properly handle the new task type.
@elasticsearchmachine elasticsearchmachine added the needs:triage and v9.3.0 labels Nov 17, 2025
@DonalEvans DonalEvans added the >enhancement, :ml and Team:ML labels and removed the needs:triage label Nov 17, 2025
@elasticsearchmachine
Collaborator

Pinging @elastic/ml-core (Team:ML)

@elasticsearchmachine
Collaborator

Hi @DonalEvans, I've created a changelog YAML for you.

- In TransportInferenceUsageAction, do not return ModelStats for the
  embedding task type if some nodes in the cluster do not support the
  embedding task
- Prevent creation of inference endpoints with the embedding task type
  if some nodes in the cluster do not support the embedding task
- Add tests for the new behaviour
Comment on lines +68 to +69
case EMBEDDING -> sendEmbeddingRequest(request, l);
case null, default -> sendInferenceActionRequest(request, l);
Contributor

I wonder if there's an opportunity for us to unify this embedding task type into the existing InferenceAction.Request, so we don't need this special handling here.

Contributor Author

The InferenceAction.Request class is already a bit bloated with fields that are only used by certain task types, so adding yet another constructor argument to it to allow us to pass in an EmbeddingRequest object feels like a step in the wrong direction to me. It would remove the need for special handling in TransportInferenceActionProxy, but add much more special handling to the InferenceAction.Request class, as well as causing knock-on changes in all the places that class is used.

Comment on lines +195 to +197
if (taskType == TaskType.ANY
    || (taskType == TaskType.EMBEDDING
        && featureService.clusterHasFeature(clusterService.state(), EMBEDDING_TASK_TYPE) == false)) {
Contributor

are we doing this only for the inference_usage.yml tests?
then I think we don't need to, we can just add the cluster feature here:

@dimitris-athanasiou - you added the extended usage, would it be okay if we just use the new cluster feature in yaml test?

Contributor Author

This isn't just for tests; if a user is in the middle of a rolling upgrade of their cluster and they request usage stats, then without this check, the request may fail.


var requestAsMap = requestToMap(request);
var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));
if (resolvedTaskType == TaskType.EMBEDDING && featureService.clusterHasFeature(state, EMBEDDING_TASK_TYPE) == false) {
Contributor

why do we need to do this? is it again only for the inference_usage tests when they run in bwc mode?
then I think we can just check the cluster_feature directly in the yaml test and we don't need this check here.

when we added new task types in the past (e.g. #119982) we never needed a similar check

Contributor Author

In the past, if a user tried to create an endpoint with a new task type in the middle of a rolling upgrade, the request could still fail, but with a more obscure message. This check just makes it more obvious to the user what the issue is and what they should do about it. As far as I know, we don't have any tests that try to create an inference endpoint with a new task type in the middle of a rolling upgrade, which is why no check was added for previous new task types. I still think it's good to have though, from the user's point of view.
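
The kind of guard being discussed can be sketched as follows. This is a simplified illustration under stated assumptions: the `TaskType` enum, the feature identifier string, the exception type, and the error message are all hypothetical, and the set of cluster-wide features stands in for the `featureService.clusterHasFeature(...)` call shown in the diff.

```java
import java.util.Set;

public class EmbeddingFeatureGateSketch {
    // Hypothetical, reduced version of the plugin's task type enum.
    enum TaskType { TEXT_EMBEDDING, EMBEDDING }

    // Hypothetical feature identifier; the real EMBEDDING_TASK_TYPE feature is defined in the PR.
    static final String EMBEDDING_TASK_TYPE = "embedding_task_type";

    /**
     * Rejects the embedding task type unless every node supports it.
     * clusterWideFeatures is assumed to hold only features present on all nodes,
     * which is what a rolling upgrade would leave incomplete.
     */
    static void validateTaskType(TaskType taskType, Set<String> clusterWideFeatures) {
        if (taskType == TaskType.EMBEDDING && clusterWideFeatures.contains(EMBEDDING_TASK_TYPE) == false) {
            // Hypothetical message: the point is to tell the user what to do next.
            throw new IllegalArgumentException(
                "Task type [embedding] is not supported by all nodes in the cluster; "
                    + "finish upgrading the cluster before creating this endpoint"
            );
        }
    }
}
```

Failing early with an explicit message is the design choice argued for above: the request would fail on a mixed-version cluster either way, but the cause is then stated directly.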

- Remove streaming field from EmbeddingAction.Request
- Adjust test to handle EMBEDDING task type
- Add SimpleEmbeddingServiceIntegrationValidator that's used when the
  task type is embedding
- Add tests for new class
- Add test case for embedding task to InferenceGetServicesIT
Contributor

@ioanatia ioanatia left a comment

I'd like to get more eyes on this from @jonathan-buttner or maybe @timgrein, who have more context on the inference API.

One thing I want to note is that at this stage it is not necessary IMO to support both formats:

"input": [
  {
    "content": {"type": "image_base64", "value": "image data"}
  },
  {
    "content": [
      {"type": "text", "value": "text input"},
      {"type": "image_base64", "value": "image data"}
    ]
  }
]

and

"input": {
  "content": {"type": "text", "value": "text input"}
}

I would just go with the first one for now since it is the more generic one.
It is a bit premature to support both at this stage.
But I don't want to block progress here - we can discuss this separately.

@DonalEvans DonalEvans requested a review from timgrein November 24, 2025 16:36
@DonalEvans
Contributor Author

> I would just go with the first one for now since it is the more generic one.
> It is a bit premature to support both at this stage.

It would actually be slightly more complicated to support only a list of content objects but not a single content object in terms of the parsing of the request. Is there a specific reason why we don't want to support a more flexible request format?

@DonalEvans DonalEvans added the test-release label Nov 24, 2025
Contributor

@jonathan-buttner jonathan-buttner left a comment

Great work Donal. I left a few questions and suggestions around tests and validation. Looks great though.

namedWriteables.addAll(writeables);
}

private static void addEmbeddingNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Contributor

Hmm do we need these as named writeables?

I think named writeables are useful when we're writing/reading an interface and we don't know the underlying type.

Here's an example: https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java#L122-L123

Could these just be Writeables?

* </pre>
* @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector
*/
public record InferenceStringGroup(List<InferenceString> inferenceStrings) implements NamedWriteable, ToXContentObject {
Contributor

I probably missed them, but do we have tests for this file? Could we create a file that extends abstract bwc for it?

Contributor Author

You didn't miss them, I just completely forgot to add them 🤦
Thanks for the reminder!


import static org.hamcrest.Matchers.is;

public class EmbeddingRequestTests extends AbstractWireSerializingTestCase<EmbeddingRequest> {
Contributor

Can we use the bwc ones here? That way if/when we add transport version checks we don't forget to switch it then.

}

public static DataType fromString(String name) {
    return valueOf(name.trim().toUpperCase(Locale.ROOT));
Contributor

I believe if the user provides an invalid string we'll get an IllegalArgumentException here which is good, but what does the error message include? Is it helpful?

I see we have a test that looks for: [InferenceString] failed to parse field [type]

Could we throw a different IllegalArgumentException that lists the valid values?

Something like:

Unrecognized type [%s], must be one of %s

Contributor Author

The exception with the [InferenceString] failed to parse field [type] message comes from the ObjectParser, and wraps the exception thrown from the valueOf() call, so I'm not 100% sure if any custom exception we throw from fromString() will actually be visible to a user, but it's easy enough to do so there's no reason not to use a more descriptive message.
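
A minimal sketch of the more descriptive message suggested above, assuming a `DataType` enum with the values `TEXT` and `IMAGE` (the actual enum constants are not shown in this conversation, so they are an assumption here, as is the wrapper class name):

```java
import java.util.Arrays;
import java.util.Locale;

public class DataTypeSketch {
    // Assumed enum values; the real DataType constants may differ.
    enum DataType {
        TEXT,
        IMAGE;

        static DataType fromString(String name) {
            try {
                return valueOf(name.trim().toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException e) {
                // Re-throw with the rejected value and the accepted values, keeping the original cause.
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "Unrecognized type [%s], must be one of %s", name, Arrays.toString(values())),
                    e
                );
            }
        }
    }
}
```

Whether this message reaches the user depends on how the wrapping parser reports the cause, as noted in the reply above.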

import static org.hamcrest.Matchers.is;

public class EmbeddingRequestTests extends AbstractWireSerializingTestCase<EmbeddingRequest> {

Contributor

Can we add a test without input_type, since it is optional?

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;

public class InferenceStringTests extends AbstractWireSerializingTestCase<InferenceString> {
Contributor

How about we make this one bwc as well?

import java.io.IOException;
import java.util.Objects;

public class EmbeddingAction extends ActionType<InferenceAction.Response> {
Contributor

I might have missed them but could we create some bwc tests for EmbeddingAction.Request?

Contributor Author

Sure thing

- Make InferenceString and InferenceStringGroup Writeable instead of
  NamedWriteable
- Add missing test classes
- Convert existing tests to backwards compatibility tests
- Provide more descriptive exception message on parser error
@DonalEvans DonalEvans removed the test-release label Dec 2, 2025
@DonalEvans DonalEvans merged commit 4e09464 into elastic:main Dec 3, 2025
34 checks passed
@DonalEvans DonalEvans deleted the embedding-task-type branch December 4, 2025 20:47
Labels: >enhancement, :ml, Team:ML, v9.3.0

4 participants