[Inference API] Add support for embedding task to JinaAI service #140323
DonalEvans merged 10 commits into elastic:main
Conversation
This commit adds support for the multimodal embedding task type to the JinaAI service. To enable this, the existing JinaAIEmbeddingsServiceSettings class has been split into two versions, one for text_embedding and one for embedding, with the common behaviour now found in the BaseJinaAIEmbeddingsServiceSettings class.

The embedding task supports models that accept multimodal inputs as well as models that only accept text inputs, so additional logic has been added to JinaAIEmbeddingsRequestEntity.toXContent() to structure the request sent to Jina appropriately for the type of model being used.

It is necessary to know whether a given list of inputs contains non-text values, both to ensure that the model being used supports multimodal inputs and to prevent late chunking from being applied, since that setting is not supported by JinaAI for multimodal inputs. To enable this, the InferenceStringGroup class now determines at construction time whether any of the InferenceStrings it contains are non-text values.

The response format used by the embedding task differs slightly from that used by the text_embedding task, so the JinaAIEmbeddingsResponseEntity class was changed to return the appropriate DenseEmbeddingResults implementation based on task type.

To support per-request task settings, additional parsing logic and a new taskSettings field have been added to the EmbeddingRequest class. This should have been present when the EmbeddingRequest class was first introduced, but it was overlooked at the time.

Other changes in this commit:
- Consolidate transport version definitions instead of defining the same transport version in multiple places for JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED and JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED
- Add test coverage for the new task type
- Greatly expand and clean up existing tests for JinaAI model and service settings classes
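The non-text detection described above is a simple eager computation. Below is a minimal sketch of the idea only; it is not the actual Elasticsearch class, and the simplified types are invented for illustration.

```java
import java.util.List;

// Minimal sketch: compute the non-text flag once at construction time rather than on every access.
final class InferenceStringGroupSketch {
    enum DataType { TEXT, IMAGE }                        // simplified stand-in for InferenceString.DataType

    record InferenceStringSketch(DataType type, String value) {}

    private final List<InferenceStringSketch> inferenceStrings;
    private final boolean containsNonTextEntry;          // computed eagerly in the constructor

    InferenceStringGroupSketch(List<InferenceStringSketch> inferenceStrings) {
        this.inferenceStrings = List.copyOf(inferenceStrings);
        // Determine once whether any entry is non-text; callers can use this to reject
        // multimodal input for text-only models and to skip late chunking.
        this.containsNonTextEntry = inferenceStrings.stream().anyMatch(s -> s.type() != DataType.TEXT);
    }

    boolean containsNonTextEntry() {
        return containsNonTextEntry;
    }

    List<InferenceStringSketch> inferenceStrings() {
        return inferenceStrings;
    }
}
```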
Pinging @elastic/search-inference-team (Team:Search - Inference)
Hi @DonalEvans, I've created a changelog YAML for you. |
Pull request overview
This PR adds support for the multimodal embedding task type to the JinaAI service. The key changes include:
- Splitting the existing JinaAIEmbeddingsServiceSettings class into two versions: JinaAITextEmbeddingServiceSettings for text_embedding and JinaAIEmbeddingServiceSettings for embedding
- Adding support for multimodal inputs (text and images) in the embedding task
- Enhancing InferenceStringGroup to track whether it contains non-text entries
- Adding per-request task settings support to EmbeddingRequest
- Consolidating transport version definitions
- Extensive test coverage for the new functionality
Reviewed changes
Copilot reviewed 39 out of 40 changed files in this pull request and generated no comments.
Summary per file:
| File | Description |
|---|---|
| BaseJinaAIEmbeddingsServiceSettings.java | New base class containing common behavior for both embedding service settings types |
| JinaAITextEmbeddingServiceSettings.java | New class for text_embedding task (renamed from original JinaAIEmbeddingsServiceSettings) |
| JinaAIEmbeddingServiceSettings.java | New class for multimodal embedding task with multimodal support |
| JinaAIService.java | Updated to support both TEXT_EMBEDDING and EMBEDDING task types, adds embeddingInfer method |
| InferenceStringGroup.java | Enhanced to track non-text entries, converted from record to class |
| EmbeddingRequest.java | Added taskSettings field for per-request configuration |
| JinaAIEmbeddingsRequestEntity.java | Updated to handle multimodal inputs with appropriate request structure |
| JinaAIEmbeddingsResponseEntity.java | Modified to return correct results type based on task type |
| ServiceSettings.java | Added isMultimodal() default method |
| Test files | Comprehensive test coverage for new functionality |
| Transport version files | Added new transport version for embedding task support |
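For context on the ServiceSettings row above, here is a rough sketch of what an isMultimodal() default method might look like. The interface shown is a heavily trimmed stand-in, and the false default is an assumption rather than something confirmed in this thread.

```java
// Trimmed, illustrative stand-in for the real ServiceSettings interface.
interface ServiceSettingsSketch {
    // Assumed default: most existing service settings describe text-only models.
    default boolean isMultimodal() {
        return false;
    }
}

// The new embedding settings can override it based on the configured model.
final class JinaAIEmbeddingServiceSettingsSketch implements ServiceSettingsSketch {
    private final boolean multimodalModel;

    JinaAIEmbeddingServiceSettingsSketch(boolean multimodalModel) {
        this.multimodalModel = multimodalModel;
    }

    @Override
    public boolean isMultimodal() {
        return multimodalModel;
    }
}
```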
jonathan-buttner
left a comment
Looks good, just a few suggestions.
```java
String stringValue = "a string";
var input = new InferenceStringGroup(stringValue);
assertThat(input.inferenceStrings(), contains(new InferenceString(DataType.TEXT, DataFormat.TEXT, stringValue)));
assertThat(input.containsNonTextEntry(), is(false));
```
nit: We could use assertFalse or assertTrue for booleans.
It's just a personal preference, but I like to use one assertion library (Hamcrest in this case) consistently, rather than mixing and matching between JUnit (which is where the assertTrue() and assertFalse() assertions come from) and Hamcrest. The slightly more verbose assertions are worth it to make things more consistent, and if we ever decide to change the preferred assertion library (AssertJ is a personal favourite that I'm sad we don't use) then it's easier to update things if they're all using the same library.
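For illustration, here is the same boolean check written in both styles under discussion. The class and variable names are made up; the point is only that choosing one library keeps the test files uniform.

```java
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertFalse;

import org.junit.Test;

public class AssertionStyleExampleTests {
    @Test
    public void bothStylesCheckTheSameThing() {
        boolean containsNonTextEntry = false; // stands in for input.containsNonTextEntry()

        // Hamcrest style, consistent with the surrounding tests:
        assertThat(containsNonTextEntry, is(false));

        // Equivalent JUnit-style assertion from the nit:
        assertFalse(containsNonTextEntry);
    }
}
```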
```diff
-.map(r -> new ChunkInferenceInput(new InferenceStringGroup(r.input), r.chunkingSettings))
+.map(
+    r -> new ChunkInferenceInput(
+        new InferenceStringGroup(singletonList(new InferenceString(InferenceString.DataType.TEXT, r.input))),
```
nit: I think we could use the helper constructor new InferenceStringGroup(r.input) to shorten this right?
Unless we're trying to be more explicit here.
Yeah, I wanted this specific place to be as explicit as possible, since it's going to need to be updated by someone from Search Relevance (I think?) at some point in the future and they won't have as much context as you or me.
```java
if (taskType == TaskType.EMBEDDING) {
    multimodalModel = removeAsType(map, MULTIMODAL_MODEL, Boolean.class);
    if (multimodalModel == null) {
        multimodalModel = true;
```
I think we want this to adhere to the default that JinaAIEmbeddingServiceSettings defines. If we add more child classes of BaseJinaAIEmbeddingsServiceSettings, we'll potentially need to do more checks here too.
What if we have this method take a lambda/interface that handles parsing this field for that particular implementation? If we feel that's too much work for only two classes and we don't envision having more for now, how about we pass the default into this method and/or a boolean that controls whether we should parse out the multimodal field?
That way this base class doesn't need specific logic for the task type.
I was able to refactor a bit so that BaseJinaAIEmbeddingsServiceSettings.fromMap() now takes a lambda that defines what to do with the multimodal_model field, and a functional interface that calls the appropriate constructor, so BaseJinaAIEmbeddingsServiceSettings.fromMap() no longer needs to know about the task type.
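As a rough sketch of that shape (the names below are invented for illustration and do not match the real Elasticsearch code), the base class can parse the shared fields while two hooks supplied by the subclass decide how the multimodal_model field is handled and which concrete settings object gets built:

```java
import java.util.Map;
import java.util.function.Function;

abstract class BaseEmbeddingsServiceSettingsSketch {

    // Stand-in for "a functional interface that calls the appropriate constructor".
    interface SettingsFactory<T extends BaseEmbeddingsServiceSettingsSketch> {
        T create(String modelId, Boolean multimodalModel);
    }

    // Generic fromMap: the base class handles the common fields and delegates the
    // task-type-specific decisions to the two hooks, so it never inspects the task type.
    static <T extends BaseEmbeddingsServiceSettingsSketch> T fromMap(
        Map<String, Object> map,
        Function<Map<String, Object>, Boolean> multimodalFieldParser, // e.g. parse-and-default, or ignore the field
        SettingsFactory<T> factory
    ) {
        String modelId = (String) map.remove("model_id"); // hypothetical common field
        Boolean multimodalModel = multimodalFieldParser.apply(map);
        return factory.create(modelId, multimodalModel);
    }
}
```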
```java
/**
 * Returns whether this {@link BaseJinaAIEmbeddingsServiceSettings} defaults to supporting multimodal inputs or not
 * @return {@code true} if these settings default to supporting multimodal inputs
 */
public abstract boolean getDefaultMultimodal();
```
This is my opinion so up to you, but for methods that return a simple result like a boolean, I tend to lean more towards having the child class pass the value as a parameter to the base class's constructor. That way the child classes don't have to implement a whole new method.
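To make the trade-off concrete, here is a small sketch of the two approaches, using illustrative names rather than the real classes:

```java
// Option A: abstract method that every child class must implement.
abstract class BaseWithAbstractGetter {
    public abstract boolean getDefaultMultimodal();
}

// Option B (the suggestion above): the child passes the value up through the constructor,
// so the base class just stores it and no per-subclass method is needed.
abstract class BaseWithConstructorParam {
    private final boolean defaultMultimodal;

    protected BaseWithConstructorParam(boolean defaultMultimodal) {
        this.defaultMultimodal = defaultMultimodal;
    }

    public boolean getDefaultMultimodal() {
        return defaultMultimodal;
    }
}

final class ChildSettingsSketch extends BaseWithConstructorParam {
    ChildSettingsSketch() {
        super(true); // hypothetical default for a multimodal-capable task type
    }
}
```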
```java
public class JinaAITextEmbeddingServiceSettings extends BaseJinaAIEmbeddingsServiceSettings {
    /**
     * This name is a holdover from before the introduction of {@link JinaAIEmbeddingServiceSettings} to support multimodal embeddings
```
nit: You could add that we can't change it, but it really should be ... text_embedding_service_settings.
```java
    embeddingType,
    dimensionsSetByUser
);
if (taskType == TaskType.EMBEDDING) {
```
It's probably overkill, but we could also pass a lambda that is basically a constructor and then make this fromMap a generic static method. That way I think we could remove the taskType checks and all the casts in the child classes.
```diff
 configurationMap.put(
     DIMENSIONS,
-    new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
+    new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.EMBEDDING)).setDescription(
```
Let's reach out to the UI team to confirm that Kibana can handle new task types. @alvarezmelissa87 Embedding is a new task type: does Kibana render the task types dynamically based on the services API response, or does Kibana need to make a change as well?
- Refactor BaseJinaAIEmbeddingsServiceSettings.fromMap() to use generics
- Make the multimodalModel field non-optional and introduce an abstract optionallyWriteMultimodalField() method to control whether it is written to XContent for implementing classes (see the sketch below)
- Move tests for BaseJinaAIEmbeddingsServiceSettings.fromMap() to the test classes for JinaAITextEmbeddingServiceSettings and JinaAIEmbeddingServiceSettings
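Below is a small sketch of the optionallyWriteMultimodalField() pattern described in that update. A BiConsumer stands in for the real XContentBuilder so the example stays self-contained, and the behaviour of each subclass is assumed rather than taken from the actual code.

```java
import java.util.function.BiConsumer;

abstract class BaseEmbeddingsXContentSketch {
    protected final boolean multimodalModel; // non-optional, per the update above

    protected BaseEmbeddingsXContentSketch(boolean multimodalModel) {
        this.multimodalModel = multimodalModel;
    }

    // Hook for subclasses: each implementation decides whether the field is written at all.
    protected abstract void optionallyWriteMultimodalField(BiConsumer<String, Boolean> fieldWriter);
}

final class TextEmbeddingXContentSketch extends BaseEmbeddingsXContentSketch {
    TextEmbeddingXContentSketch() {
        super(false);
    }

    @Override
    protected void optionallyWriteMultimodalField(BiConsumer<String, Boolean> fieldWriter) {
        // Assumed: text_embedding settings have no multimodal_model field, so write nothing.
    }
}

final class EmbeddingXContentSketch extends BaseEmbeddingsXContentSketch {
    EmbeddingXContentSketch(boolean multimodalModel) {
        super(multimodalModel);
    }

    @Override
    protected void optionallyWriteMultimodalField(BiConsumer<String, Boolean> fieldWriter) {
        // Assumed: embedding settings always write the field.
        fieldWriter.accept("multimodal_model", multimodalModel);
    }
}
```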
jonathan-buttner
left a comment
Thanks for the changes!