diff --git a/docs/my-website/docs/providers/watsonx/rerank.md b/docs/my-website/docs/providers/watsonx/rerank.md new file mode 100644 index 00000000000..0900ce96781 --- /dev/null +++ b/docs/my-website/docs/providers/watsonx/rerank.md @@ -0,0 +1,52 @@ +# watsonx.ai Rerank + +## Overview + +| Property | Details | +|----------|--------------------------------------------------------------------------| +| Description | watsonx.ai rerank integration | +| Provider Route on LiteLLM | `watsonx/` | +| Supported Operations | `/ml/v1/text/rerank` | +| Link to Provider Doc | [IBM WatsonX.ai ↗](https://cloud.ibm.com/apidocs/watsonx-ai#text-rerank) | + +## Quick Start + +### **LiteLLM SDK** + +```python +import os +from litellm import rerank + +os.environ["WATSONX_APIKEY"] = "YOUR_WATSONX_APIKEY" +os.environ["WATSONX_API_BASE"] = "YOUR_WATSONX_API_BASE" +os.environ["WATSONX_PROJECT_ID"] = "YOUR_WATSONX_PROJECT_ID" + +query="Best programming language for beginners?" +documents=[ + "Python is great for beginners due to simple syntax.", + "JavaScript runs in browsers and is versatile.", + "Rust has a steep learning curve but is very safe.", +] + +response = rerank( + model="watsonx/cross-encoder/ms-marco-minilm-l-12-v2", + query=query, + documents=documents, + top_n=2, + return_documents=True, +) + +print(response) +``` + +### **LiteLLM Proxy** + +```yaml +model_list: + - model_name: cross-encoder/ms-marco-minilm-l-12-v2 + litellm_params: + model: watsonx/cross-encoder/ms-marco-minilm-l-12-v2 + api_key: os.environ/WATSONX_APIKEY + api_base: os.environ/WATSONX_API_BASE + project_id: os.environ/WATSONX_PROJECT_ID +``` diff --git a/docs/my-website/docs/rerank.md b/docs/my-website/docs/rerank.md index 90f685d2bbd..9c76883d7fd 100644 --- a/docs/my-website/docs/rerank.md +++ b/docs/my-website/docs/rerank.md @@ -8,15 +8,15 @@ LiteLLM Follows the [cohere api request / response for the rerank api](https://c ## Overview -| Feature | Supported | Notes | -|---------|-----------|-------| -| Cost Tracking | ✅ | Works with all supported models | -| Logging | ✅ | Works across all integrations | -| End-user Tracking | ✅ | | -| Fallbacks | ✅ | Works between supported models | -| Loadbalancing | ✅ | Works between supported models | -| Guardrails | ✅ | Applies to input query only (not documents) | -| Supported Providers | Cohere, Together AI, Azure AI, DeepInfra, Nvidia NIM, Infinity, Fireworks AI, Voyage AI | | +| Feature | Supported | Notes | +|---------|-----------------------------------------------------------------------------------------------------|-------| +| Cost Tracking | ✅ | Works with all supported models | +| Logging | ✅ | Works across all integrations | +| End-user Tracking | ✅ | | +| Fallbacks | ✅ | Works between supported models | +| Loadbalancing | ✅ | Works between supported models | +| Guardrails | ✅ | Applies to input query only (not documents) | +| Supported Providers | Cohere, Together AI, Azure AI, DeepInfra, Nvidia NIM, Infinity, Fireworks AI, Voyage AI, watsonx.ai | | ## **LiteLLM Python SDK Usage** ### Quick Start @@ -123,17 +123,18 @@ curl http://0.0.0.0:4000/rerank \ #### ⚡️See all supported models and providers at [models.litellm.ai](https://models.litellm.ai/) -| Provider | Link to Usage | -|-------------|--------------------| -| Cohere (v1 + v2 clients) | [Usage](#quick-start) | -| Together AI| [Usage](../docs/providers/togetherai) | -| Azure AI| [Usage](../docs/providers/azure_ai#rerank-endpoint) | -| Jina AI| [Usage](../docs/providers/jina_ai) | -| AWS Bedrock| [Usage](../docs/providers/bedrock#rerank-api) | -| HuggingFace| [Usage](../docs/providers/huggingface_rerank) | -| Infinity| [Usage](../docs/providers/infinity) | -| vLLM| [Usage](../docs/providers/vllm#rerank-endpoint) | -| DeepInfra| [Usage](../docs/providers/deepinfra#rerank-endpoint) | -| Vertex AI| [Usage](../docs/providers/vertex#rerank-api) | -| Fireworks AI| [Usage](../docs/providers/fireworks_ai#rerank-endpoint) | -| Voyage AI| [Usage](../docs/providers/voyage#rerank) | \ No newline at end of file +| Provider | Link to Usage | +|--------------------------|------------------------------------------------------| +| Cohere (v1 + v2 clients) | [Usage](#quick-start) | +| Together AI | [Usage](../docs/providers/togetherai) | +| Azure AI | [Usage](../docs/providers/azure_ai#rerank-endpoint) | +| Jina AI | [Usage](../docs/providers/jina_ai) | +| AWS Bedrock | [Usage](../docs/providers/bedrock#rerank-api) | +| HuggingFace | [Usage](../docs/providers/huggingface_rerank) | +| Infinity | [Usage](../docs/providers/infinity) | +| vLLM | [Usage](../docs/providers/vllm#rerank-endpoint) | +| DeepInfra | [Usage](../docs/providers/deepinfra#rerank-endpoint) | +| Vertex AI | [Usage](../docs/providers/vertex#rerank-api) | +| Fireworks AI | [Usage](../docs/providers/fireworks_ai#rerank-endpoint) | +| Voyage AI | [Usage](../docs/providers/voyage#rerank) | +| IBM watsonx.ai | [Usage](../docs/providers/watsonx/rerank) | \ No newline at end of file diff --git a/litellm/__init__.py b/litellm/__init__.py index 4aaddc3da76..c13ae8c2d1c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1333,6 +1333,7 @@ def set_global_gitlab_config(config: Dict[str, Any]) -> None: from .llms.vertex_ai.rerank.transformation import VertexAIRerankConfig as VertexAIRerankConfig from .llms.fireworks_ai.rerank.transformation import FireworksAIRerankConfig as FireworksAIRerankConfig from .llms.voyage.rerank.transformation import VoyageRerankConfig as VoyageRerankConfig + from .llms.watsonx.rerank.transformation import IBMWatsonXRerankConfig as IBMWatsonXRerankConfig from .llms.clarifai.chat.transformation import ClarifaiConfig as ClarifaiConfig from .llms.ai21.chat.transformation import AI21ChatConfig as AI21ChatConfig from .llms.meta_llama.chat.transformation import LlamaAPIConfig as LlamaAPIConfig diff --git a/litellm/_lazy_imports_registry.py b/litellm/_lazy_imports_registry.py index 2af6ed8f09e..a3dc12c23a3 100644 --- a/litellm/_lazy_imports_registry.py +++ b/litellm/_lazy_imports_registry.py @@ -155,6 +155,7 @@ "VertexAIRerankConfig", "FireworksAIRerankConfig", "VoyageRerankConfig", + "IBMWatsonXRerankConfig", "ClarifaiConfig", "AI21ChatConfig", "LlamaAPIConfig", @@ -671,6 +672,7 @@ "FireworksAIRerankConfig", ), "VoyageRerankConfig": (".llms.voyage.rerank.transformation", "VoyageRerankConfig"), + "IBMWatsonXRerankConfig": (".llms.watsonx.rerank.transformation", "IBMWatsonXRerankConfig"), "ClarifaiConfig": (".llms.clarifai.chat.transformation", "ClarifaiConfig"), "AI21ChatConfig": (".llms.ai21.chat.transformation", "AI21ChatConfig"), "LlamaAPIConfig": (".llms.meta_llama.chat.transformation", "LlamaAPIConfig"), diff --git a/litellm/llms/watsonx/__init__.py b/litellm/llms/watsonx/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/litellm/llms/watsonx/chat/__init__.py b/litellm/llms/watsonx/chat/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/litellm/llms/watsonx/completion/__init__.py b/litellm/llms/watsonx/completion/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/litellm/llms/watsonx/embed/__init__.py b/litellm/llms/watsonx/embed/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/litellm/llms/watsonx/rerank/__init__.py b/litellm/llms/watsonx/rerank/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/litellm/llms/watsonx/rerank/transformation.py b/litellm/llms/watsonx/rerank/transformation.py new file mode 100644 index 00000000000..7b4c2a07c3c --- /dev/null +++ b/litellm/llms/watsonx/rerank/transformation.py @@ -0,0 +1,204 @@ +""" +Transformation logic for IBM watsonx.ai's /ml/v1/text/rerank endpoint. + +Docs - https://cloud.ibm.com/apidocs/watsonx-ai#text-rerank +""" + +import uuid +from typing import Any, Dict, List, Optional, Union, cast + +import httpx + +from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.watsonx import ( + WatsonXAIEndpoint, +) +from litellm.types.rerank import ( + RerankResponse, + RerankResponseMeta, + RerankTokens, +) + +from ..common_utils import IBMWatsonXMixin, _generate_watsonx_token, _get_api_params + + +class IBMWatsonXRerankConfig(IBMWatsonXMixin, BaseRerankConfig): + """ + IBM watsonx.ai Rerank API configuration + """ + + def get_complete_url( + self, + api_base: Optional[str], + model: str, + optional_params: Optional[dict] = None, + ) -> str: + base_url = self._get_base_url(api_base=api_base) + endpoint = WatsonXAIEndpoint.RERANK.value + + url = base_url.rstrip("/") + endpoint + + params = optional_params or {} + + complete_url = self._add_api_version_to_url(url=url, api_version=(params.get("api_version", None))) + return complete_url + + def get_supported_cohere_rerank_params(self, model: str) -> list: + return [ + "query", + "documents", + "top_n", + "return_documents", + "max_tokens_per_doc", + ] + + def validate_environment( # type: ignore[override] + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + optional_params: Optional[dict] = None, + ) -> Dict: + optional_params = optional_params or {} + + default_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + if "Authorization" in headers: + return {**default_headers, **headers} + token = cast( + Optional[str], + optional_params.pop("token", None) or get_secret_str("WATSONX_TOKEN"), + ) + zen_api_key = cast( + Optional[str], + optional_params.pop("zen_api_key", None) or get_secret_str("WATSONX_ZENAPIKEY"), + ) + if token: + headers["Authorization"] = f"Bearer {token}" + elif zen_api_key: + headers["Authorization"] = f"ZenApiKey {zen_api_key}" + else: + token = _generate_watsonx_token(api_key=api_key, token=token) + # build auth headers + headers["Authorization"] = f"Bearer {token}" + return {**default_headers, **headers} + + def map_cohere_rerank_params( + self, + non_default_params: Optional[dict], + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + max_tokens_per_doc: Optional[int] = None, + ) -> Dict: + """ + Map Cohere rerank params to IBM watsonx.ai rerank params + """ + optional_rerank_params = {} + if non_default_params is not None: + for k, v in non_default_params.items(): + if k == "query" and v is not None: + optional_rerank_params["query"] = v + elif k == "documents" and v is not None: + optional_rerank_params["inputs"] = [ + {"text": el} if isinstance(el, str) else el for el in v + ] + elif k == "top_n" and v is not None: + optional_rerank_params.setdefault("parameters", {}).setdefault("return_options", {})["top_n"] = v + elif k == "return_documents" and v is not None and isinstance(v, bool): + optional_rerank_params.setdefault("parameters", {}).setdefault("return_options", {})["inputs"] = v + elif k == "max_tokens_per_doc" and v is not None: + optional_rerank_params.setdefault("parameters", {})["truncate_input_tokens"] = v + + # IBM watsonx.ai require one of below parameters + elif k == "project_id" and v is not None: + optional_rerank_params["project_id"] = v + elif k == "space_id" and v is not None: + optional_rerank_params["space_id"] = v + + return dict(optional_rerank_params) + + def transform_rerank_request( + self, + model: str, + optional_rerank_params: Dict, + headers: dict, + ) -> dict: + """ + Transform request to IBM watsonx.ai rerank format + """ + watsonx_api_params = _get_api_params(params=optional_rerank_params, model=model) + watsonx_auth_payload = self._prepare_payload( + model=model, + api_params=watsonx_api_params, + ) + + return optional_rerank_params | watsonx_auth_payload + + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + """ + Transform IBM watsonx.ai rerank response to LiteLLM RerankResponse format + """ + try: + raw_response_json = raw_response.json() + except Exception as e: + raise self.get_error_class( + error_message=f"Failed to parse response: {str(e)}", + status_code=raw_response.status_code, + headers=raw_response.headers, + ) + + _results: Optional[List[dict]] = raw_response_json.get("results") + if _results is None: + raise ValueError(f"No results found in the response={raw_response_json}") + + transformed_results = [] + + for result in _results: + transformed_result: Dict[str, Any] = { + "index": result["index"], + "relevance_score": result["score"], + } + + if "input" in result: + if isinstance(result["input"], str): + transformed_result["document"] = {"text": result["input"]} + else: + transformed_result["document"] = result["input"] + + transformed_results.append(transformed_result) + + response_id = raw_response_json.get("id") or raw_response_json.get("model_id") or str(uuid.uuid4()) + + # Extract usage information + _tokens = RerankTokens( + input_tokens=raw_response_json.get("input_token_count", 0), + ) + rerank_meta = RerankResponseMeta(tokens=_tokens) + + return RerankResponse( + id=response_id, + results=transformed_results, # type: ignore + meta=rerank_meta, + ) diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 8910d37fbe7..f47fd6323f0 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -10,6 +10,7 @@ from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler from litellm.llms.together_ai.rerank.handler import TogetherAIRerank +from litellm.llms.watsonx.common_utils import IBMWatsonXMixin from litellm.rerank_api.rerank_utils import get_optional_rerank_params from litellm.secret_managers.main import get_secret, get_secret_str from litellm.types.rerank import RerankResponse @@ -29,7 +30,7 @@ async def arerank( model: str, query: str, documents: List[Union[str, Dict[str, Any]]], - custom_llm_provider: Optional[Literal["cohere", "together_ai", "deepinfra", "fireworks_ai", "voyage"]] = None, + custom_llm_provider: Optional[Literal["cohere", "together_ai", "deepinfra", "fireworks_ai", "voyage", "watsonx"]] = None, top_n: Optional[int] = None, rank_fields: Optional[List[str]] = None, return_documents: Optional[bool] = None, @@ -85,6 +86,7 @@ def rerank( # noqa: PLR0915 "deepinfra", "fireworks_ai", "voyage", + "watsonx", ] ] = None, top_n: Optional[int] = None, @@ -478,6 +480,31 @@ def rerank( # noqa: PLR0915 or get_secret_str("VOYAGE_API_BASE") ) + response = base_llm_http_handler.rerank( + model=model, + custom_llm_provider=_custom_llm_provider, + provider_config=rerank_provider_config, + optional_rerank_params=optional_rerank_params, + logging_obj=litellm_logging_obj, + timeout=optional_params.timeout, + api_key=api_key, + api_base=api_base, + _is_async=_is_async, + headers=headers or litellm.headers or {}, + client=client, + model_response=model_response, + ) + elif _custom_llm_provider == litellm.LlmProviders.WATSONX: + credentials = IBMWatsonXMixin.get_watsonx_credentials( + optional_params=dict(optional_params), api_key=dynamic_api_key, api_base=dynamic_api_base + ) + + api_key = credentials["api_key"] + api_base = credentials["api_base"] + + if credentials.get("token") is not None: + optional_rerank_params["token"] = credentials["token"] + response = base_llm_http_handler.rerank( model=model, custom_llm_provider=_custom_llm_provider, diff --git a/litellm/types/llms/watsonx.py b/litellm/types/llms/watsonx.py index 137090b032e..21e58500c6f 100644 --- a/litellm/types/llms/watsonx.py +++ b/litellm/types/llms/watsonx.py @@ -63,6 +63,7 @@ class WatsonXAIEndpoint(str, Enum): EMBEDDINGS = "/ml/v1/text/embeddings" PROMPTS = "/ml/v1/prompts" AVAILABLE_MODELS = "/ml/v1/foundation_model_specs" + RERANK = "/ml/v1/text/rerank" class WatsonXModelPattern(str, Enum): diff --git a/litellm/utils.py b/litellm/utils.py index 0fd21f09919..0e8dada2352 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8145,6 +8145,8 @@ def get_provider_rerank_config( return litellm.FireworksAIRerankConfig() elif litellm.LlmProviders.VOYAGE == provider: return litellm.VoyageRerankConfig() + elif litellm.LlmProviders.WATSONX == provider: + return litellm.IBMWatsonXRerankConfig() return litellm.CohereRerankConfig() @staticmethod diff --git a/tests/test_litellm/llms/watsonx/rerank/__init__.py b/tests/test_litellm/llms/watsonx/rerank/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/llms/watsonx/rerank/test_watsonx_rerank.py b/tests/test_litellm/llms/watsonx/rerank/test_watsonx_rerank.py new file mode 100644 index 00000000000..f50966279b4 --- /dev/null +++ b/tests/test_litellm/llms/watsonx/rerank/test_watsonx_rerank.py @@ -0,0 +1,236 @@ +""" +Tests for IBM watsonx.ai rerank transformation functionality. +""" +import json +import re +import uuid +from unittest.mock import MagicMock + +import httpx +import pytest + +from litellm.llms.watsonx.common_utils import ( + WatsonXAIError, +) +from litellm.llms.watsonx.rerank.transformation import IBMWatsonXRerankConfig +from litellm.types.rerank import RerankResponse + + +class TestIBMWatsonXRerankTransform: + def setup_method(self): + self.config = IBMWatsonXRerankConfig() + self.model = "watsonx/cross-encoder/ms-marco-minilm-l-12-v2" + + def test_get_complete_url(self): + """Test URL generation for IBM watsonx.ai rerank API.""" + + api_base = "https://us-south.ml.cloud.ibm.com" + model = "watsonx/cross-encoder/ms-marco-minilm-l-12-v2" + url = self.config.get_complete_url(api_base, model) + assert url == "https://us-south.ml.cloud.ibm.com/ml/v1/text/rerank?version=2024-03-13" + + def test_map_cohere_rerank_params_basic(self): + """Test basic parameter mapping for IBM watsonx.ai rerank.""" + params = self.config.map_cohere_rerank_params( + non_default_params={ + "query": "hello", + "documents": ["hello", "world"], + "top_n": 2, + "return_documents": True, + "max_tokens_per_doc": 100, + }, + model="test", + drop_params=False, + query="hello", + documents=["hello", "world"], + ) + assert params["query"] == "hello" + assert params["inputs"] == [{"text": "hello"}, {"text": "world"}] + assert params["parameters"]["return_options"]["top_n"] == 2 + assert params["parameters"]["return_options"]["inputs"] is True + assert params["parameters"]["truncate_input_tokens"] == 100 + + def test_transform_rerank_request(self): + """Test request transformation for IBM watsonx.ai format.""" + optional_params = { + "query": "What is the capital of France?", + "documents": [ + "Paris is the capital of France.", + "France is a country in Europe.", + ], + "top_n": 2, + "return_documents": True, + "project_id": uuid.uuid4(), + } + + request_body = self.config.transform_rerank_request( + model="cross-encoder/ms-marco-minilm-l-12-v2", optional_rerank_params=optional_params, headers={} + ) + + assert request_body["model_id"] == "cross-encoder/ms-marco-minilm-l-12-v2" + assert request_body["project_id"] is not None + assert request_body["query"] == "What is the capital of France?" + assert request_body["documents"] == optional_params["documents"] + assert request_body["top_n"] == 2 + assert request_body["return_documents"] is True + + def test_transform_rerank_request_missing_scope(self): + """Test that transform_rerank_request raises error for missing scope.""" + optional_params = { + "documents": ["doc1"], + } + expected_error_msg = re.escape( + "Watsonx project_id and space_id not set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter." + ) + + with pytest.raises(WatsonXAIError, match=expected_error_msg): + self.config.transform_rerank_request(model=self.model, optional_rerank_params=optional_params, headers={}) + + def test_transform_rerank_response_success(self): + """Test successful response transformation.""" + # Mock IBM watsonx.ai response format + response_data = { + "model_id": self.model, + "results": [ + { + "index": 0, + "score": 6.53515625, + "input": {"text": "Python is great for beginners due to simple syntax."}, + }, + {"index": 1, "score": -7.1875, "input": {"text": "JavaScript runs in browsers and is versatile."}}, + ], + "input_token_count": 62, + } + + # Create mock httpx response + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = response_data + mock_response.status_code = 200 + mock_response.headers = {} + + # Create mock logging object + mock_logging = MagicMock() + + model_response = RerankResponse() + + result = self.config.transform_rerank_response( + model=self.model, + raw_response=mock_response, + model_response=model_response, + logging_obj=mock_logging, + ) + + # Verify response structure + # IBM watsonx.ai doesn't return "id", so it uses "model" as the id + assert result.id == "watsonx/cross-encoder/ms-marco-minilm-l-12-v2" + assert len(result.results) == 2 + assert result.results[0]["index"] == 0 + assert result.results[0]["relevance_score"] == 6.53515625 + assert result.results[0]["document"]["text"] == "Python is great for beginners due to simple syntax." + assert result.results[1]["index"] == 1 + assert result.results[1]["relevance_score"] == -7.1875 + assert result.results[1]["document"]["text"] == "JavaScript runs in browsers and is versatile." + + # # Verify metadata + assert result.meta["tokens"]["input_tokens"] == 62 + + def test_transform_rerank_response_without_documents(self): + """Test response transformation when return_documents is False.""" + response_data = { + "model_id": self.model, + "results": [ + { + "index": 0, + "score": 6.53515625, + }, + { + "index": 1, + "score": -7.1875, + }, + ], + "input_token_count": 62, + } + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = response_data + mock_response.status_code = 200 + mock_response.headers = {} + + mock_logging = MagicMock() + model_response = RerankResponse() + + result = self.config.transform_rerank_response( + model=self.model, + raw_response=mock_response, + model_response=model_response, + logging_obj=mock_logging, + ) + + # Verify response structure + # IBM watsonx.ai doesn't return "id", so it uses "model" as the id + assert result.id == "watsonx/cross-encoder/ms-marco-minilm-l-12-v2" + assert len(result.results) == 2 + + assert result.results[0]["index"] == 0 + assert result.results[0]["relevance_score"] == 6.53515625 + assert "document" not in result.results[0] + + assert result.results[1]["index"] == 1 + assert result.results[1]["relevance_score"] == -7.1875 + assert "document" not in result.results[1] + + def test_transform_rerank_response_missing_results(self): + """Test that missing results raises ValueError.""" + response_data = { + "model": self.model, + "usage": {"total_tokens": 10}, + } + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = response_data + mock_response.status_code = 200 + mock_response.headers = {} + + mock_logging = MagicMock() + model_response = RerankResponse() + + expected_error_msg = re.escape("No results found") + + with pytest.raises(ValueError, match=expected_error_msg): + self.config.transform_rerank_response( + model=self.model, + raw_response=mock_response, + model_response=model_response, + logging_obj=mock_logging, + ) + + def test_transform_rerank_response_invalid_json(self): + """Test error handling for invalid JSON response.""" + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "doc", 0) + mock_response.text = "Invalid JSON response" + mock_response.status_code = 500 + mock_response.headers = {} + + mock_logging = MagicMock() + model_response = RerankResponse() + + expected_error_msg = re.escape("Failed to parse response") + + with pytest.raises(Exception, match=expected_error_msg): + self.config.transform_rerank_response( + model=self.model, + raw_response=mock_response, + model_response=model_response, + logging_obj=mock_logging, + ) + + def test_get_supported_cohere_rerank_params(self): + """Test getting supported parameters for IBM watsonx.ai rerank.""" + supported_params = self.config.get_supported_cohere_rerank_params(self.model) + assert "query" in supported_params + assert "documents" in supported_params + assert "top_n" in supported_params + assert "return_documents" in supported_params + assert "max_tokens_per_doc" in supported_params + assert len(supported_params) == 5