Merged
52 changes: 52 additions & 0 deletions docs/my-website/docs/providers/watsonx/rerank.md
@@ -0,0 +1,52 @@
# watsonx.ai Rerank

## Overview

| Property | Details |
|----------|--------------------------------------------------------------------------|
| Description | watsonx.ai rerank integration |
| Provider Route on LiteLLM | `watsonx/` |
| Supported Operations | `/ml/v1/text/rerank` |
| Link to Provider Doc | [IBM watsonx.ai ↗](https://cloud.ibm.com/apidocs/watsonx-ai#text-rerank) |

## Quick Start

### **LiteLLM SDK**

```python
import os
from litellm import rerank

os.environ["WATSONX_APIKEY"] = "YOUR_WATSONX_APIKEY"
os.environ["WATSONX_API_BASE"] = "YOUR_WATSONX_API_BASE"
os.environ["WATSONX_PROJECT_ID"] = "YOUR_WATSONX_PROJECT_ID"

query="Best programming language for beginners?"
documents=[
"Python is great for beginners due to simple syntax.",
"JavaScript runs in browsers and is versatile.",
"Rust has a steep learning curve but is very safe.",
]

response = rerank(
model="watsonx/cross-encoder/ms-marco-minilm-l-12-v2",
query=query,
documents=documents,
top_n=2,
return_documents=True,
)

print(response)
```
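
`print(response)` emits LiteLLM's Cohere-style `RerankResponse`. An illustrative sketch of its shape, with made-up values rather than real output:

```python
# Illustrative shape only; the values below are made up:
{
    "id": "some-uuid",  # response id (hypothetical)
    "results": [
        {"index": 0, "relevance_score": 0.91,
         "document": {"text": "Python is great for beginners due to simple syntax."}},
        {"index": 2, "relevance_score": 0.34,
         "document": {"text": "Rust has a steep learning curve but is very safe."}},
    ],
    "meta": {"tokens": {"input_tokens": 42}},
}
```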

### **LiteLLM Proxy**

```yaml
model_list:
  - model_name: cross-encoder/ms-marco-minilm-l-12-v2
    litellm_params:
      model: watsonx/cross-encoder/ms-marco-minilm-l-12-v2
      api_key: os.environ/WATSONX_APIKEY
      api_base: os.environ/WATSONX_API_BASE
      project_id: os.environ/WATSONX_PROJECT_ID
```
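
Once the proxy is running with this config (`litellm --config config.yaml`), call its `/rerank` endpoint. A minimal sketch, assuming the proxy listens on the default `0.0.0.0:4000` and uses a proxy key of `sk-1234` (both are assumptions, adjust for your deployment):

```python
import requests

# Call the LiteLLM proxy's /rerank route with the model_name from the config above.
response = requests.post(
    "http://0.0.0.0:4000/rerank",
    headers={"Authorization": "Bearer sk-1234"},  # your proxy API key
    json={
        "model": "cross-encoder/ms-marco-minilm-l-12-v2",
        "query": "Best programming language for beginners?",
        "documents": [
            "Python is great for beginners due to simple syntax.",
            "JavaScript runs in browsers and is versatile.",
        ],
        "top_n": 1,
    },
)
print(response.json())
```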
47 changes: 24 additions & 23 deletions docs/my-website/docs/rerank.md
@@ -8,15 +8,15 @@ LiteLLM follows the [cohere api request / response for the rerank api](https://c

## Overview

| Feature | Supported | Notes |
|---------|-----------|-------|
| Cost Tracking | ✅ | Works with all supported models |
| Logging | ✅ | Works across all integrations |
| End-user Tracking | ✅ | |
| Fallbacks | ✅ | Works between supported models |
| Loadbalancing | ✅ | Works between supported models |
| Guardrails | ✅ | Applies to input query only (not documents) |
| Supported Providers | Cohere, Together AI, Azure AI, DeepInfra, Nvidia NIM, Infinity, Fireworks AI, Voyage AI, watsonx.ai | |

## **LiteLLM Python SDK Usage**
### Quick Start
@@ -123,17 +123,18 @@ curl http://0.0.0.0:4000/rerank \

#### ⚡️See all supported models and providers at [models.litellm.ai](https://models.litellm.ai/)

| Provider | Link to Usage |
|--------------------------|------------------------------------------------------|
| Cohere (v1 + v2 clients) | [Usage](#quick-start) |
| Together AI | [Usage](../docs/providers/togetherai) |
| Azure AI | [Usage](../docs/providers/azure_ai#rerank-endpoint) |
| Jina AI | [Usage](../docs/providers/jina_ai) |
| AWS Bedrock | [Usage](../docs/providers/bedrock#rerank-api) |
| HuggingFace | [Usage](../docs/providers/huggingface_rerank) |
| Infinity | [Usage](../docs/providers/infinity) |
| vLLM | [Usage](../docs/providers/vllm#rerank-endpoint) |
| DeepInfra | [Usage](../docs/providers/deepinfra#rerank-endpoint) |
| Vertex AI | [Usage](../docs/providers/vertex#rerank-api) |
| Fireworks AI | [Usage](../docs/providers/fireworks_ai#rerank-endpoint) |
| Voyage AI | [Usage](../docs/providers/voyage#rerank) |
| IBM watsonx.ai | [Usage](../docs/providers/watsonx/rerank) |
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -1333,6 +1333,7 @@ def set_global_gitlab_config(config: Dict[str, Any]) -> None:
from .llms.vertex_ai.rerank.transformation import VertexAIRerankConfig as VertexAIRerankConfig
from .llms.fireworks_ai.rerank.transformation import FireworksAIRerankConfig as FireworksAIRerankConfig
from .llms.voyage.rerank.transformation import VoyageRerankConfig as VoyageRerankConfig
from .llms.watsonx.rerank.transformation import IBMWatsonXRerankConfig as IBMWatsonXRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig as ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig as AI21ChatConfig
from .llms.meta_llama.chat.transformation import LlamaAPIConfig as LlamaAPIConfig
2 changes: 2 additions & 0 deletions litellm/_lazy_imports_registry.py
@@ -155,6 +155,7 @@
"VertexAIRerankConfig",
"FireworksAIRerankConfig",
"VoyageRerankConfig",
"IBMWatsonXRerankConfig",
"ClarifaiConfig",
"AI21ChatConfig",
"LlamaAPIConfig",
@@ -671,6 +672,7 @@
"FireworksAIRerankConfig",
),
"VoyageRerankConfig": (".llms.voyage.rerank.transformation", "VoyageRerankConfig"),
"IBMWatsonXRerankConfig": (".llms.watsonx.rerank.transformation", "IBMWatsonXRerankConfig"),
"ClarifaiConfig": (".llms.clarifai.chat.transformation", "ClarifaiConfig"),
"AI21ChatConfig": (".llms.ai21.chat.transformation", "AI21ChatConfig"),
"LlamaAPIConfig": (".llms.meta_llama.chat.transformation", "LlamaAPIConfig"),
204 changes: 204 additions & 0 deletions litellm/llms/watsonx/rerank/transformation.py
@@ -0,0 +1,204 @@
"""
Transformation logic for IBM watsonx.ai's /ml/v1/text/rerank endpoint.

Docs - https://cloud.ibm.com/apidocs/watsonx-ai#text-rerank
"""

import uuid
from typing import Any, Dict, List, Optional, Union, cast

import httpx

from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import (
    WatsonXAIEndpoint,
)
from litellm.types.rerank import (
    RerankResponse,
    RerankResponseMeta,
    RerankTokens,
)

from ..common_utils import IBMWatsonXMixin, _generate_watsonx_token, _get_api_params


class IBMWatsonXRerankConfig(IBMWatsonXMixin, BaseRerankConfig):
    """
    IBM watsonx.ai Rerank API configuration
    """

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        base_url = self._get_base_url(api_base=api_base)
        endpoint = WatsonXAIEndpoint.RERANK.value

        url = base_url.rstrip("/") + endpoint
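        # e.g. "https://us-south.ml.cloud.ibm.com/ml/v1/text/rerank" before the
        # API version query parameter is appended below (region is illustrative).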

        params = optional_params or {}

        complete_url = self._add_api_version_to_url(url=url, api_version=params.get("api_version", None))
        return complete_url

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            "return_documents",
            "max_tokens_per_doc",
        ]

    def validate_environment(  # type: ignore[override]
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> Dict:
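        """
        Build the request headers for the rerank call.

        Auth precedence, as implemented below: an explicit Authorization header
        wins; otherwise a bearer token (param or WATSONX_TOKEN), then a Zen API
        key (param or WATSONX_ZENAPIKEY), and finally a bearer token generated
        from the API key.
        """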
        optional_params = optional_params or {}

        default_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        if "Authorization" in headers:
            return {**default_headers, **headers}
        token = cast(
            Optional[str],
            optional_params.pop("token", None) or get_secret_str("WATSONX_TOKEN"),
        )
        zen_api_key = cast(
            Optional[str],
            optional_params.pop("zen_api_key", None) or get_secret_str("WATSONX_ZENAPIKEY"),
        )
        if token:
            headers["Authorization"] = f"Bearer {token}"
        elif zen_api_key:
            headers["Authorization"] = f"ZenApiKey {zen_api_key}"
        else:
            token = _generate_watsonx_token(api_key=api_key, token=token)
            # build auth headers
            headers["Authorization"] = f"Bearer {token}"
        return {**default_headers, **headers}

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
"""
Map Cohere rerank params to IBM watsonx.ai rerank params
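
        Illustrative mapping (shape only):
            {"query": q, "documents": ["a", "b"], "top_n": 2, "return_documents": True}
            maps to
            {"query": q, "inputs": [{"text": "a"}, {"text": "b"}],
             "parameters": {"return_options": {"top_n": 2, "inputs": True}}}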
"""
optional_rerank_params = {}
if non_default_params is not None:
for k, v in non_default_params.items():
if k == "query" and v is not None:
optional_rerank_params["query"] = v
elif k == "documents" and v is not None:
optional_rerank_params["inputs"] = [
{"text": el} if isinstance(el, str) else el for el in v
]
elif k == "top_n" and v is not None:
optional_rerank_params.setdefault("parameters", {}).setdefault("return_options", {})["top_n"] = v
elif k == "return_documents" and v is not None and isinstance(v, bool):
optional_rerank_params.setdefault("parameters", {}).setdefault("return_options", {})["inputs"] = v
elif k == "max_tokens_per_doc" and v is not None:
optional_rerank_params.setdefault("parameters", {})["truncate_input_tokens"] = v

# IBM watsonx.ai require one of below parameters
elif k == "project_id" and v is not None:
optional_rerank_params["project_id"] = v
elif k == "space_id" and v is not None:
optional_rerank_params["space_id"] = v

return dict(optional_rerank_params)

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        """
        Transform request to IBM watsonx.ai rerank format
        """
        watsonx_api_params = _get_api_params(params=optional_rerank_params, model=model)
        watsonx_auth_payload = self._prepare_payload(
            model=model,
            api_params=watsonx_api_params,
        )

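        # Merge the mapped rerank params with the watsonx payload fields
        # (e.g. model / project_id / space_id resolved by the mixin).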
        return optional_rerank_params | watsonx_auth_payload

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform IBM watsonx.ai rerank response to LiteLLM RerankResponse format
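
        A raw watsonx response looks roughly like (illustrative):
            {"results": [{"index": 0, "score": 0.91, "input": {"text": "..."}}],
             "input_token_count": 42}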
"""
try:
raw_response_json = raw_response.json()
except Exception as e:
raise self.get_error_class(
error_message=f"Failed to parse response: {str(e)}",
status_code=raw_response.status_code,
headers=raw_response.headers,
)

_results: Optional[List[dict]] = raw_response_json.get("results")
if _results is None:
raise ValueError(f"No results found in the response={raw_response_json}")

transformed_results = []

for result in _results:
transformed_result: Dict[str, Any] = {
"index": result["index"],
"relevance_score": result["score"],
}

if "input" in result:
if isinstance(result["input"], str):
transformed_result["document"] = {"text": result["input"]}
else:
transformed_result["document"] = result["input"]

transformed_results.append(transformed_result)

response_id = raw_response_json.get("id") or raw_response_json.get("model_id") or str(uuid.uuid4())

# Extract usage information
_tokens = RerankTokens(
input_tokens=raw_response_json.get("input_token_count", 0),
)
rerank_meta = RerankResponseMeta(tokens=_tokens)

return RerankResponse(
id=response_id,
results=transformed_results, # type: ignore
meta=rerank_meta,
)
29 changes: 28 additions & 1 deletion litellm/rerank_api/main.py
@@ -10,6 +10,7 @@
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.rerank import RerankResponse
@@ -29,7 +30,7 @@ async def arerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[Literal["cohere", "together_ai", "deepinfra", "fireworks_ai", "voyage", "watsonx"]] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = None,
@@ -85,6 +86,7 @@ def rerank(  # noqa: PLR0915
"deepinfra",
"fireworks_ai",
"voyage",
"watsonx",
]
] = None,
top_n: Optional[int] = None,
@@ -478,6 +480,31 @@
            or get_secret_str("VOYAGE_API_BASE")
        )

        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            provider_config=rerank_provider_config,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
            api_key=api_key,
            api_base=api_base,
            _is_async=_is_async,
            headers=headers or litellm.headers or {},
            client=client,
            model_response=model_response,
        )
    elif _custom_llm_provider == litellm.LlmProviders.WATSONX:
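        # watsonx can authenticate with a raw API key or a pre-generated IAM
        # token; resolve whichever is configured before dispatching the request.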
        credentials = IBMWatsonXMixin.get_watsonx_credentials(
            optional_params=dict(optional_params), api_key=dynamic_api_key, api_base=dynamic_api_base
        )

        api_key = credentials["api_key"]
        api_base = credentials["api_base"]

        if credentials.get("token") is not None:
            optional_rerank_params["token"] = credentials["token"]

        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,