[ML] Add Embedding inference task type #138198
This commit introduces a new TaskType for the inference plugin named
"embedding" which allows both images and text to be used as inputs to
create dense vectors. This task invokes the new
InferenceService.embeddingInfer() method which is currently implemented
to throw an UnsupportedOperationException if called on non-test
implementations of InferenceService.
The input for this task is specified using a list
of "content" objects, each of which specifies the type of input ("text"
or "image_base64") and the String value of the input:
"input": [
{
"content": {"type": "image_base64", "value": "image data"}
},
{
"content": [
{"type": "text", "value": "text input"},
{"type": "image_base64", "value": "image data"}
]
}
]
It is also possible to specify a single content object rather than a
list:
"input": {
"content": {"type": "text", "value": "text input"}
}
Each content object in the request will result in a single embedding
array in the response, meaning that multiple texts or images can be
grouped into a single content object and converted into a single
embedding array if the model being used supports that.
To preserve input compatibility with the existing text_embedding task,
the input can also be specified as a single String or a list of Strings,
each of which will be internally parsed as a content object with type
"text":
"input": "single text input"
OR
"input": ["first text input", "second text input"]
The output format for this task uses the GenericDenseEmbedding*Results
classes which produce a format like the one below (with additional
"_bytes" or "_bits" suffixes on the "embeddings" array name if those
element types are specified).
{
embeddings=[
{
embedding=[54.0, ... 51.0]
},
{
embedding=[45.0, ... 56.0]
}
]
}
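The suffix rule for the "embeddings" array name can be sketched as below. This is only an illustration of the naming described above: the `ElementType` enum and the `EmbeddingsArrayName` helper are hypothetical names, not classes from this PR.

```java
// Hypothetical element types; the real code may name or model these differently.
enum ElementType {
    FLOAT,
    BYTE,
    BIT
}

final class EmbeddingsArrayName {
    private EmbeddingsArrayName() {}

    // Returns the name of the top-level array in the response for a given element type,
    // following the "_bytes" / "_bits" suffix rule from the description.
    static String forElementType(ElementType elementType) {
        return switch (elementType) {
            case FLOAT -> "embeddings";
            case BYTE -> "embeddings_bytes";
            case BIT -> "embeddings_bits";
        };
    }
}
```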
An InferenceStringGroup class is introduced to represent content objects
with more than one text or image input, which is necessary to support
the use case of multiple inputs generating a single embedding. This
class is redundant right now, because all code paths assume a 1:1
mapping between number of inputs and number of embeddings generated, but
the implementation is designed to support this use case being added in
future.
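As a rough sketch of the input handling described above, the accepted forms (a single String, a list of Strings, a single content object, or a list of content objects) can all be normalized to one group per expected embedding. This is not the parser used in the PR: `ContentItem`, `ContentGroup` and `EmbeddingInputNormalizer` are hypothetical stand-ins for `InferenceString`, `InferenceStringGroup` and the real request parsing, and the input is assumed to be already-decoded JSON (Strings, Maps and Lists).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for InferenceString / InferenceStringGroup; not the real classes.
record ContentItem(String type, String value) {}

record ContentGroup(List<ContentItem> items) {}

final class EmbeddingInputNormalizer {
    private EmbeddingInputNormalizer() {}

    // Accepts a String, a content object (Map), or a List of either.
    // Returns one group per element, i.e. one group per embedding in the response.
    static List<ContentGroup> normalize(Object input) {
        List<ContentGroup> groups = new ArrayList<>();
        if (input instanceof List<?> list) {
            for (Object element : list) {
                groups.add(toGroup(element));
            }
        } else {
            groups.add(toGroup(input));
        }
        return groups;
    }

    private static ContentGroup toGroup(Object element) {
        if (element instanceof String text) {
            // Bare strings keep text_embedding compatibility: a single "text" item.
            return new ContentGroup(List.of(new ContentItem("text", text)));
        }
        if (element instanceof Map<?, ?> object && object.containsKey("content")) {
            Object content = object.get("content");
            List<ContentItem> items = new ArrayList<>();
            if (content instanceof List<?> contentList) {
                // Several items in one content object are grouped into one embedding.
                for (Object item : contentList) {
                    items.add(toItem(item));
                }
            } else {
                items.add(toItem(content));
            }
            return new ContentGroup(List.copyOf(items));
        }
        throw new IllegalArgumentException("Unsupported input element: " + element);
    }

    private static ContentItem toItem(Object item) {
        if (item instanceof Map<?, ?> map
            && map.get("type") instanceof String type
            && map.get("value") instanceof String value) {
            return new ContentItem(type, value);
        }
        throw new IllegalArgumentException("Content must have String [type] and [value] fields: " + item);
    }
}
```

With this shape, `normalize(List.of("first text input", "second text input"))` yields two groups of one text item each, while a content object holding a text and an image yields a single group with two items.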
This commit also adds testing for the new task type and updates existing
tests and code to properly handle the new task type.
Pinging @elastic/ml-core (Team:ML)
Hi @DonalEvans, I've created a changelog YAML for you.
- In TransportInferenceUsageAction, do not return ModelStats for the embedding task type if some nodes in the cluster do not support the embedding task
- Prevent creation of inference endpoints with the embedding task type if some nodes in the cluster do not support the embedding task
- Add tests for the new behaviour
...nce/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
...ce/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java
case EMBEDDING -> sendEmbeddingRequest(request, l);
case null, default -> sendInferenceActionRequest(request, l);
I wonder if there's an opportunity for us to unify this embedding task type into the existing InferenceAction.Request, so we don't need this special handling here.
The InferenceAction.Request class is already a bit bloated with fields that are only used by certain task types, so adding yet another constructor argument to it to allow us to pass in an EmbeddingRequest object feels like a step in the wrong direction to me. It would remove the need for special handling in TransportInferenceActionProxy, but add much more special handling to the InferenceAction.Request class, as well as causing knock-on changes in all the places that class is used.
if (taskType == TaskType.ANY
    || (taskType == TaskType.EMBEDDING
        && featureService.clusterHasFeature(clusterService.state(), EMBEDDING_TASK_TYPE) == false)) {
are we doing this only for the inference_usage.yml tests?
then I think we don't need to, we can just add the cluster feature here:
@dimitris-athanasiou - you added the extended usage, would it be okay if we just use the new cluster feature in yaml test?
This isn't just for tests; if a user is in the middle of a rolling upgrade of their cluster and they request usage stats, then without this check, the request may fail.
var requestAsMap = requestToMap(request);
var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));
if (resolvedTaskType == TaskType.EMBEDDING && featureService.clusterHasFeature(state, EMBEDDING_TASK_TYPE) == false) {
why do we need to do this? is it again only for the inference_usage tests when they run in bwc mode?
then I think we can just check the cluster_feature directly in the yaml test and we don't need this check here.
when we added new task types in the past (e.g. #119982) we never needed a similar check
In the past, if a user tried to create an endpoint with a new task type in the middle of a rolling upgrade, the request could still fail, but with a more obscure message. This check just makes it more obvious to the user what the issue is and what they should do about it. As far as I know, we don't have any tests that try to create an inference endpoint with a new task type in the middle of a rolling upgrade, which is why no check was added for previous new task types. I still think it's good to have though, from the user's point of view.
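To make the intent of this check concrete, here is a minimal sketch of a feature gate that rejects the embedding task type with an explicit message while the cluster is only partially upgraded. It is not the code in this PR: `ClusterFeatureView` is a hypothetical stand-in for the `featureService.clusterHasFeature(...)` call in the diff, `SketchTaskType` and `EmbeddingTaskTypeGate` are made-up names, and the feature id string, exception type and message wording are placeholders.

```java
// Hypothetical stand-in for the cluster-wide feature lookup used in the diff.
interface ClusterFeatureView {
    boolean clusterHasFeature(String featureId);
}

// Reduced, hypothetical version of the task type enum.
enum SketchTaskType {
    TEXT_EMBEDDING,
    EMBEDDING
}

final class EmbeddingTaskTypeGate {
    // Placeholder feature id; the real identifier is not shown in this thread.
    static final String EMBEDDING_TASK_TYPE = "embedding_task_type_placeholder";

    private EmbeddingTaskTypeGate() {}

    // Fails fast with an actionable message instead of letting the request
    // reach a node that does not know the embedding task type.
    static void ensureSupported(SketchTaskType taskType, ClusterFeatureView features) {
        if (taskType == SketchTaskType.EMBEDDING && features.clusterHasFeature(EMBEDDING_TASK_TYPE) == false) {
            throw new IllegalStateException(
                "The [embedding] task type is not supported by all nodes in the cluster; "
                    + "finish upgrading all nodes and try again"
            );
        }
    }
}
```

Other task types pass through unchanged, so the gate only affects requests that would otherwise fail with a less obvious error during a rolling upgrade.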
- Remove streaming field from EmbeddingAction.Request
- Adjust test to handle EMBEDDING task type
- Add SimpleEmbeddingServiceIntegrationValidator that's used when the task type is embedding
- Add tests for new class
- Add test case for embedding task to InferenceGetServicesIT
ioanatia
left a comment
I'd like to get more eyes on this from @jonathan-buttner or maybe @timgrein that have more context on the inference API.
One thing I want to note is that at this stage it is not necessary IMO to support both formats:
"input": [
{
"content": {"type": "image_base64", "value": "image data"}
},
{
"content": [
{"type": "text", "value": "text input"},
{"type": "image_base64", "value": "image data"}
]
}
]
and
"input": {
"content": {"type": "text", "value": "text input"}
}
I would just go with the first one for now since it is the more generic one.
It is a bit premature to support both at this stage.
But I don't want to block progress here - we can discuss this separately.
It would actually be slightly more complicated to support only a list of content objects but not a single content object in terms of the parsing of the request. Is there a specific reason why we don't want to support a more flexible request format?
jonathan-buttner
left a comment
Great work Donal. I left a few questions and suggestions around tests and validation. Looks great though.
    namedWriteables.addAll(writeables);
}

private static void addEmbeddingNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Hmm do we need these as named writeables?
I think named writeables are useful when we're writing/reading an interface and we don't know the underlying type.
Here's an example: https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java#L122-L123
Could these just be Writeables?
 * </pre>
 * @param inferenceStrings the list of {@link InferenceString} which should result in generating a single embedding vector
 */
public record InferenceStringGroup(List<InferenceString> inferenceStrings) implements NamedWriteable, ToXContentObject {
I probably missed them, but do we have tests for this file? Could we create a file that extends abstract bwc for it?
You didn't miss them, I just completely forgot to add them 🤦
Thanks for the reminder!
import static org.hamcrest.Matchers.is;

public class EmbeddingRequestTests extends AbstractWireSerializingTestCase<EmbeddingRequest> {
Can we use the bwc ones here? That way if/when we add transport version checks we don't forget to switch it then.
}

public static DataType fromString(String name) {
    return valueOf(name.trim().toUpperCase(Locale.ROOT));
I believe if the user provides an invalid string we'll get an IllegalArgumentException here which is good, but what does the error message include? Is it helpful?
I see we have a test that looks for: [InferenceString] failed to parse field [type]
Could we throw a different IllegalArgumentException that includes the valid values?
Something like:
Unrecognized type [%s], must be one of %s
The exception with the [InferenceString] failed to parse field [type] message comes from the ObjectParser, and wraps the exception thrown from the valueOf() call, so I'm not 100% sure if any custom exception we throw from fromString() will actually be visible to a user, but it's easy enough to do so there's no reason not to use a more descriptive message.
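A minimal sketch of the more descriptive error, assuming an enum shaped like the `DataType` in the excerpt above. The constants mirror the "text" and "image_base64" types from the PR description at this point in the review; the lower-casing `toString()` override is an assumption added so that the listed values match what a user would type, and is not taken from the PR.

```java
import java.util.Arrays;
import java.util.Locale;

// Hypothetical reconstruction of the enum under review; only fromString() appears in the diff excerpt.
enum DataType {
    TEXT,
    IMAGE_BASE64;

    static DataType fromString(String name) {
        try {
            return valueOf(name.trim().toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Re-throw with the rejected value and the accepted values spelled out.
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Unrecognized type [%s], must be one of %s", name, Arrays.toString(values())),
                e
            );
        }
    }

    @Override
    public String toString() {
        // Assumed: render constants the way they appear in requests.
        return name().toLowerCase(Locale.ROOT);
    }
}
```

Calling `DataType.fromString("video")` on this sketch throws an `IllegalArgumentException` with the message `Unrecognized type [video], must be one of [text, image_base64]`. Whether that message reaches the user still depends on how the wrapping parser exception is reported, as noted above.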
import static org.hamcrest.Matchers.is;

public class EmbeddingRequestTests extends AbstractWireSerializingTestCase<EmbeddingRequest> {
Can we add a test without input_type, since it is optional?
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;

public class InferenceStringTests extends AbstractWireSerializingTestCase<InferenceString> {
How about we make this one bwc as well?
import java.io.IOException;
import java.util.Objects;

public class EmbeddingAction extends ActionType<InferenceAction.Response> {
I might have missed them but could we create some bwc tests for EmbeddingAction.Request?
- Make InferenceString and InferenceStringGroup Writeable instead of NamedWriteable
- Add missing test classes
- Convert existing tests to backwards compatibility tests
- Provide more descriptive exception message on parser error
- Rename IMAGE_BASE64 to IMAGE
- Introduce DataFormat enum in InferenceString
- Add and update tests
This commit introduces a new TaskType for the inference plugin named "embedding" which allows both images and text to be used as inputs to create dense vectors. This task invokes the new InferenceService.embeddingInfer() method which is currently implemented to throw an UnsupportedOperationException if called on non-test implementations of InferenceService.
The input for this task is specified using a list of "content" objects, each of which specifies the type of input ("text" or "image"), the format of the input ("text" or "base64") and the String value of the input. The "format" field is optional and defaults to "text" for the "text" input type and "base64" for the "image" input type.
It is also possible to specify a single content object rather than a list.
Each content object in the request will result in a single embedding array in the response, meaning that multiple texts or images can be grouped into a single content object and converted into a single embedding array if the model being used supports that.
To preserve input compatibility with the existing text_embedding task, the input can also be specified as a single String or a list of Strings, each of which will be internally parsed as a content object with type "text" and format "text".
The output format for this task uses the GenericDenseEmbedding*Results classes, which produce an "embeddings" array of "embedding" objects (with additional "_bytes" or "_bits" suffixes on the "embeddings" array name if those element types are specified).
An InferenceStringGroup class is introduced to represent content objects with more than one text or image input, which is necessary to support the use case of multiple inputs generating a single embedding. This class is redundant right now, because all code paths assume a 1:1 mapping between number of inputs and number of embeddings generated, but the implementation is designed to support this use case being added in future.
This commit also adds testing for the new task type and updates existing tests and code to properly handle the new task type.