Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,7 @@ def set_global_gitlab_config(config: Dict[str, Any]) -> None:
from .llms.azure.chat.gpt_5_transformation import AzureOpenAIGPT5Config as AzureOpenAIGPT5Config
from .llms.azure.completion.transformation import AzureOpenAITextConfig as AzureOpenAITextConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig as HostedVLLMChatConfig
from .llms.hosted_vllm.embedding.transformation import HostedVLLMEmbeddingConfig as HostedVLLMEmbeddingConfig
from .llms.github_copilot.chat.transformation import GithubCopilotConfig as GithubCopilotConfig
from .llms.github_copilot.responses.transformation import GithubCopilotResponsesAPIConfig as GithubCopilotResponsesAPIConfig
from .llms.github_copilot.embedding.transformation import GithubCopilotEmbeddingConfig as GithubCopilotEmbeddingConfig
Expand Down
1,044 changes: 830 additions & 214 deletions litellm/_lazy_imports_registry.py

Large diffs are not rendered by default.

180 changes: 180 additions & 0 deletions litellm/llms/hosted_vllm/embedding/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
Hosted VLLM Embedding API Configuration.

This module provides the configuration for hosted VLLM's Embedding API.
VLLM is OpenAI-compatible and supports embeddings via the /v1/embeddings endpoint.

Docs: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
"""

from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse
from litellm.utils import convert_to_model_response_object

if TYPE_CHECKING:
    # Import the real logging type only for static type checkers; keeping it
    # out of the runtime path avoids a circular-import risk with litellm core.
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # At runtime, fall back to Any so annotations referencing
    # LiteLLMLoggingObj remain valid without the heavyweight import.
    LiteLLMLoggingObj = Any


class HostedVLLMEmbeddingError(BaseLLMException):
    """Raised when the Hosted VLLM embedding endpoint returns an error response."""


class HostedVLLMEmbeddingConfig(BaseEmbeddingConfig):
    """
    Configuration for Hosted VLLM's Embedding API.

    VLLM serves an OpenAI-compatible ``/v1/embeddings`` endpoint, so both the
    request and the response follow the OpenAI embedding schema; the
    transforms below are therefore mostly pass-throughs plus provider-prefix
    and auth handling.

    Reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Build request headers for the Hosted VLLM API.

        Falls back to the ``HOSTED_VLLM_API_KEY`` secret when no ``api_key``
        is supplied; the placeholder ``"fake-api-key"`` covers unauthenticated
        vLLM servers, in which case no Authorization header is sent.

        Returns:
            dict: merged headers; caller-supplied ``headers`` win on conflict.
        """
        if api_key is None:
            api_key = get_secret_str("HOSTED_VLLM_API_KEY") or "fake-api-key"

        default_headers = {
            "Content-Type": "application/json",
        }

        # The placeholder key marks an unauthenticated server — omit the
        # Authorization header entirely in that case.
        if api_key and api_key != "fake-api-key":
            default_headers["Authorization"] = f"Bearer {api_key}"

        # Merge with existing headers (user's headers take priority)
        return {**default_headers, **headers}

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Resolve the full URL of the Hosted VLLM embedding endpoint.

        Reads ``HOSTED_VLLM_API_BASE`` when ``api_base`` is not given and
        appends ``/embeddings`` unless the base already ends with it.

        Raises:
            ValueError: if no api_base can be resolved.
        """
        if api_base is None:
            api_base = get_secret_str("HOSTED_VLLM_API_BASE")
        if api_base is None:
            raise ValueError("api_base is required for hosted_vllm embeddings")

        # Normalize away trailing slashes before checking/appending the path.
        api_base = api_base.rstrip("/")

        if not api_base.endswith("/embeddings"):
            api_base = f"{api_base}/embeddings"

        return api_base

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform an embedding request into Hosted VLLM (OpenAI-compatible) form.

        A bare string input is wrapped in a list, and the litellm provider
        prefix ``hosted_vllm/`` is stripped from the model name before it is
        sent to the server.
        """
        if isinstance(input, str):
            input = [input]

        # Strip only a leading 'hosted_vllm/' provider prefix (count=1 guards
        # against mangling a model name that contains the substring later).
        if model.startswith("hosted_vllm/"):
            model = model.replace("hosted_vllm/", "", 1)

        return {
            "model": model,
            "input": input,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """
        Transform the raw HTTP response into a litellm ``EmbeddingResponse``.

        VLLM returns a standard OpenAI-compatible embedding payload, so the
        generic converter handles the mapping.
        """
        # Record the raw body for logging/observability before parsing.
        logging_obj.post_call(original_response=raw_response.text)

        response_json = raw_response.json()

        return convert_to_model_response_object(
            response_object=response_json,
            model_response_object=model_response,
            response_type="embedding",
        )

    def get_supported_openai_params(self, model: str) -> list:
        """
        Return the OpenAI embedding parameters Hosted VLLM accepts.
        """
        return [
            "timeout",
            "dimensions",
            "encoding_format",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Copy supported OpenAI parameters into ``optional_params``.

        Unsupported parameters are silently ignored regardless of
        ``drop_params`` (kept for interface compatibility with other configs).
        """
        # Hoist the supported-parameter lookup out of the loop and use a set
        # for O(1) membership tests instead of re-building the list per param.
        supported = set(self.get_supported_openai_params(model))
        for param, value in non_default_params.items():
            if param in supported:
                optional_params[param] = value
        return optional_params

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Wrap an HTTP error into the provider-specific exception type.
        """
        return HostedVLLMEmbeddingError(
            message=error_message,
            status_code=status_code,
            headers=headers,
        )
88 changes: 69 additions & 19 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,33 @@ def completion( # type: ignore # noqa: PLR0915
or "https://api.minimax.io/v1"
)

response = base_llm_http_handler.completion(
model=model,
messages=messages,
api_base=api_base,
custom_llm_provider=custom_llm_provider,
model_response=model_response,
encoding=_get_encoding(),
logging_obj=logging,
optional_params=optional_params,
timeout=timeout,
litellm_params=litellm_params,
shared_session=shared_session,
acompletion=acompletion,
stream=stream,
api_key=api_key,
headers=headers,
client=client,
provider_config=provider_config,
)
logging.post_call(
input=messages, api_key=api_key, original_response=response
)
elif custom_llm_provider == "hosted_vllm":
api_base = (
api_base or litellm.api_base or get_secret_str("HOSTED_VLLM_API_BASE")
)

response = base_llm_http_handler.completion(
model=model,
messages=messages,
Expand Down Expand Up @@ -3611,9 +3638,9 @@ def completion( # type: ignore # noqa: PLR0915
"aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None
):
optional_params[
"aws_region_name"
] = aws_bedrock_client.meta.region_name
optional_params["aws_region_name"] = (
aws_bedrock_client.meta.region_name
)

bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse":
Expand Down Expand Up @@ -4773,9 +4800,32 @@ def embedding( # noqa: PLR0915
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "hosted_vllm":
api_base = (
api_base or litellm.api_base or get_secret_str("HOSTED_VLLM_API_BASE")
)

# set API KEY
if api_key is None:
api_key = litellm.api_key or get_secret_str("HOSTED_VLLM_API_KEY")

response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params=litellm_params_dict,
headers=headers or {},
)
elif (
custom_llm_provider == "openai_like"
or custom_llm_provider == "hosted_vllm"
or custom_llm_provider == "llamafile"
or custom_llm_provider == "lm_studio"
):
Expand Down Expand Up @@ -5948,9 +5998,9 @@ def adapter_completion(
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
translated_response: Optional[
Union[BaseModel, AdapterCompletionStreamWrapper]
] = None
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
None
)
if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params(
response=response
Expand Down Expand Up @@ -6655,9 +6705,9 @@ def speech( # noqa: PLR0915
ElevenLabsTextToSpeechConfig.ELEVENLABS_QUERY_PARAMS_KEY
] = query_params

litellm_params_dict[
ElevenLabsTextToSpeechConfig.ELEVENLABS_VOICE_ID_KEY
] = voice_id
litellm_params_dict[ElevenLabsTextToSpeechConfig.ELEVENLABS_VOICE_ID_KEY] = (
voice_id
)

if api_base is not None:
litellm_params_dict["api_base"] = api_base
Expand Down Expand Up @@ -7163,9 +7213,9 @@ def stream_chunk_builder( # noqa: PLR0915
]

if len(content_chunks) > 0:
response["choices"][0]["message"][
"content"
] = processor.get_combined_content(content_chunks)
response["choices"][0]["message"]["content"] = (
processor.get_combined_content(content_chunks)
)

thinking_blocks = [
chunk
Expand All @@ -7176,9 +7226,9 @@ def stream_chunk_builder( # noqa: PLR0915
]

if len(thinking_blocks) > 0:
response["choices"][0]["message"][
"thinking_blocks"
] = processor.get_combined_thinking_content(thinking_blocks)
response["choices"][0]["message"]["thinking_blocks"] = (
processor.get_combined_thinking_content(thinking_blocks)
)

reasoning_chunks = [
chunk
Expand All @@ -7189,9 +7239,9 @@ def stream_chunk_builder( # noqa: PLR0915
]

if len(reasoning_chunks) > 0:
response["choices"][0]["message"][
"reasoning_content"
] = processor.get_combined_reasoning_content(reasoning_chunks)
response["choices"][0]["message"]["reasoning_content"] = (
processor.get_combined_reasoning_content(reasoning_chunks)
)

annotation_chunks = [
chunk
Expand Down
Loading
Loading