diff --git a/docs/changelog/140331.yaml b/docs/changelog/140331.yaml new file mode 100644 index 0000000000000..19568786d098e --- /dev/null +++ b/docs/changelog/140331.yaml @@ -0,0 +1,6 @@ +pr: 140331 +summary: "[Inference API] Include rerank in supported tasks for IBM watsonx integration" +area: Inference +type: bug +issues: + - 140328 diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_watsonx.json b/rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_watsonx.json index e9852eda3048e..cd7b3688ac883 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_watsonx.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_watsonx.json @@ -28,7 +28,8 @@ "options": [ "text_embedding", "chat_completion", - "completion" + "completion", + "rerank" ] }, "watsonx_inference_id": { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MultiValuedBinaryDocValuesField.java b/server/src/main/java/org/elasticsearch/index/mapper/MultiValuedBinaryDocValuesField.java index 75b418e0c0082..d036b74fe94ba 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MultiValuedBinaryDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MultiValuedBinaryDocValuesField.java @@ -54,7 +54,7 @@ public int count() { protected void writeLenAndValues(BytesStreamOutput out) throws IOException { // sort the ArrayList variant of the collection prior to serializing it into a binary array if (values instanceof ArrayList list) { - list.sort(Comparator.naturalOrder()); + list.sort(Comparator.naturalOrder()); } for (BytesRef value : values) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index e356edec7d40c..69976cb5d6b82 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -149,6 +149,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { "openshift_ai", "test_reranking_service", "voyageai", + "watsonxai", "hugging_face", "amazon_sagemaker", "elastic" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 24273fac6b61b..a7ffe1c668651 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -24,6 +24,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -68,21 +69,28 @@ import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID; -public class IbmWatsonxService extends SenderService { - - public static final String NAME = "watsonxai"; +public class IbmWatsonxService extends SenderService implements RerankingInferenceService { private static final String SERVICE_NAME = "IBM watsonx"; private static final EnumSet supportedTaskTypes = EnumSet.of( TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, - TaskType.CHAT_COMPLETION + TaskType.CHAT_COMPLETION, + TaskType.RERANK ); private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler( "IBM watsonx chat completions", OpenAiChatCompletionResponseEntity::fromResponse ); + public static final String NAME = "watsonxai"; + + // IBM watsonx has a single rerank model with a token limit of 512 + // (see https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx#reranker-overview) + // Using 1 token = 0.75 words as a rough estimate, we get 384 words + // allowing for some headroom, we set the window size below 384 words + public static final int RERANK_WINDOW_SIZE = 350; + public IbmWatsonxService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, @@ -362,6 +370,11 @@ protected IbmWatsonxActionCreator getActionCreator(Sender sender, ServiceCompone return new IbmWatsonxActionCreator(getSender(), getServiceComponents()); } + @Override + public int rerankerWindowSize(String modelId) { + return RERANK_WINDOW_SIZE; + } + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 659e5c70c7677..792fed20970e7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.http.MockResponse; @@ -82,6 +83,7 @@ import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService.RERANK_WINDOW_SIZE; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; @@ -943,7 +945,7 @@ public void testGetConfiguration() throws Exception { { "service": "watsonxai", "name": "IBM watsonx", - "task_types": ["text_embedding", "completion", "chat_completion"], + "task_types": ["text_embedding", "rerank", "completion", "chat_completion"], "configurations": { "project_id": { "description": "", @@ -952,7 +954,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -961,7 +963,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] }, "api_version": { "description": "The IBM watsonx API version ID to use.", @@ -970,7 +972,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -988,7 +990,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "completion", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] } } } @@ -1050,6 +1052,11 @@ public InferenceService createInferenceService() { return createIbmWatsonxService(); } + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(RERANK_WINDOW_SIZE)); + } + private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService { IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents, mockClusterServiceEmpty());