Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions enterprise/litellm_enterprise/proxy/hooks/managed_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ async def async_pre_call_hook( # noqa: PLR0915
if (
call_type == CallTypes.afile_content.value
or call_type == CallTypes.afile_delete.value
or call_type == CallTypes.afile_retrieve.value
):
await self.check_managed_file_id_access(data, user_api_key_dict)

Expand Down Expand Up @@ -433,12 +435,16 @@ async def async_pre_call_hook( # noqa: PLR0915
data["model_file_id_mapping"] = model_file_id_mapping
elif (
call_type == CallTypes.aretrieve_batch.value
or call_type == CallTypes.acancel_batch.value
or call_type == CallTypes.acancel_fine_tuning_job.value
or call_type == CallTypes.aretrieve_fine_tuning_job.value
):
accessor_key: Optional[str] = None
retrieve_object_id: Optional[str] = None
if call_type == CallTypes.aretrieve_batch.value:
if (
call_type == CallTypes.aretrieve_batch.value
or call_type == CallTypes.acancel_batch.value
):
accessor_key = "batch_id"
elif (
call_type == CallTypes.acancel_fine_tuning_job.value
Expand All @@ -454,6 +460,8 @@ async def async_pre_call_hook( # noqa: PLR0915
if retrieve_object_id
else False
)
verbose_logger.debug(f"potential_llm_object_id: {potential_llm_object_id}")
verbose_logger.debug(f"retrieve_object_id: {retrieve_object_id}")
if potential_llm_object_id and retrieve_object_id:
## VALIDATE USER HAS ACCESS TO THE OBJECT ##
if not await self.can_user_call_unified_object_id(
Expand Down Expand Up @@ -966,8 +974,10 @@ async def afile_delete(
delete_response = None
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
# Remove conflicting keys from data to avoid duplicate keyword arguments
filtered_data = {k: v for k, v in data.items() if k not in ("model", "file_id")}
for model_id, model_file_id in specific_model_file_id_mapping.items():
delete_response = await llm_router.afile_delete(model=model_id, file_id=model_file_id, **data) # type: ignore
delete_response = await llm_router.afile_delete(model=model_id, file_id=model_file_id, **filtered_data) # type: ignore

stored_file_object = await self.delete_unified_file_id(
file_id, litellm_parent_otel_span
Expand Down
5 changes: 2 additions & 3 deletions litellm/batches/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
Expand Down Expand Up @@ -868,7 +867,7 @@ async def acancel_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
) -> LiteLLMBatch:
"""
Async: Cancels a batch.

Expand Down Expand Up @@ -912,7 +911,7 @@ def cancel_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Cancels a batch.

Expand Down
23 changes: 18 additions & 5 deletions litellm/llms/azure/batches/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from typing import Any, Coroutine, Optional, Union, cast

import httpx

from openai import AsyncOpenAI, OpenAI

from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
Expand Down Expand Up @@ -130,9 +128,9 @@ async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
client: Union[AsyncAzureOpenAI, AsyncOpenAI],
) -> Batch:
) -> LiteLLMBatch:
response = await client.batches.cancel(**cancel_batch_data)
return response
return LiteLLMBatch(**response.model_dump())

def cancel_batch(
self,
Expand Down Expand Up @@ -160,8 +158,23 @@ def cancel_batch(
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)

if _is_async is True:
if not isinstance(azure_client, (AsyncAzureOpenAI, AsyncOpenAI)):
raise ValueError(
"Azure client is not an instance of AsyncAzureOpenAI or AsyncOpenAI. Make sure you passed an async client."
)
return self.acancel_batch( # type: ignore
cancel_batch_data=cancel_batch_data, client=azure_client
)

# At this point, azure_client is guaranteed to be a sync client
if not isinstance(azure_client, (AzureOpenAI, OpenAI)):
raise ValueError(
"Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client."
)
response = azure_client.batches.cancel(**cancel_batch_data)
return response
return LiteLLMBatch(**response.model_dump())

async def alist_batches(
self,
Expand Down
11 changes: 8 additions & 3 deletions litellm/llms/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,10 +1923,10 @@ async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
) -> LiteLLMBatch:
verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data)
response = await openai_client.batches.cancel(**cancel_batch_data)
return response
return LiteLLMBatch(**response.model_dump())

def cancel_batch(
self,
Expand Down Expand Up @@ -1962,8 +1962,13 @@ def cancel_batch(
cancel_batch_data=cancel_batch_data, openai_client=openai_client
)

# At this point, openai_client is guaranteed to be a sync OpenAI client
if not isinstance(openai_client, OpenAI):
raise ValueError(
"OpenAI client is not an instance of OpenAI. Make sure you passed a sync OpenAI client."
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response
return LiteLLMBatch(**response.model_dump())

async def alist_batches(
self,
Expand Down
2 changes: 2 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ class LiteLLMRoutes(enum.Enum):
"/batches",
"/v1/batches/{batch_id}",
"/batches/{batch_id}",
"/v1/batches/{batch_id}/cancel",
"/batches/{batch_id}/cancel",
# files
"/v1/files",
"/files",
Expand Down
2 changes: 0 additions & 2 deletions litellm/proxy/auth/user_api_key_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,6 @@ async def user_api_key_auth(
request_data=request_data, request=request
)
route: str = get_request_route(request=request)

## CHECK IF ROUTE IS ALLOWED

user_api_key_auth_obj = await _user_api_key_auth_builder(
Expand All @@ -1263,7 +1262,6 @@ async def user_api_key_auth(
user_api_key_auth_obj.end_user_id = end_user_id

user_api_key_auth_obj.request_route = normalize_request_route(route)

return user_api_key_auth_obj


Expand Down
70 changes: 33 additions & 37 deletions litellm/proxy/batches_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
_is_base64_encoded_unified_file_id,
decode_model_from_file_id,
encode_file_id_with_model,
get_batch_id_from_unified_batch_id,
get_credentials_for_model,
get_model_id_from_unified_batch_id,
get_models_from_unified_file_id,
get_original_file_id,
prepare_data_with_credentials,
Expand Down Expand Up @@ -382,25 +380,6 @@ async def retrieve_batch(
**data # type: ignore
)

# Re-encode all IDs in the response
if response:
if hasattr(response, "id") and response.id:
response.id = batch_id # Keep the encoded batch ID

if hasattr(response, "input_file_id") and response.input_file_id:
response.input_file_id = encode_file_id_with_model(
file_id=response.input_file_id, model=model_from_id
)

if hasattr(response, "output_file_id") and response.output_file_id:
response.output_file_id = encode_file_id_with_model(
file_id=response.output_file_id, model=model_from_id
)

if hasattr(response, "error_file_id") and response.error_file_id:
response.error_file_id = encode_file_id_with_model(
file_id=response.error_file_id, model=model_from_id
)

verbose_proxy_logger.debug(
f"Retrieved batch using model: {model_from_id}, original_id: {original_batch_id}"
Expand Down Expand Up @@ -695,15 +674,31 @@ async def cancel_batch(

data: Dict = {}
try:
data = await _read_request_body(request=request)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)

# Check for encoded batch ID with model info
model_from_id = decode_model_from_file_id(batch_id)

# Create CancelBatchRequest with batch_id to enable ownership checking
_cancel_batch_request = CancelBatchRequest(
batch_id=batch_id,
)
data = cast(dict, _cancel_batch_request)

unified_batch_id = _is_base64_encoded_unified_file_id(batch_id)

base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
(
data,
litellm_logging_obj,
) = await base_llm_response_processor.common_processing_pre_call_logic(
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_logging_obj=proxy_logging_obj,
proxy_config=proxy_config,
route_type="acancel_batch",
)

# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
Expand Down Expand Up @@ -751,17 +746,13 @@ async def cancel_batch(
},
)

model = (
get_model_id_from_unified_batch_id(unified_batch_id)
if unified_batch_id
else None
)

model_batch_id = get_batch_id_from_unified_batch_id(unified_batch_id)

data["batch_id"] = model_batch_id

response = await llm_router.acancel_batch(model=model, **data) # type: ignore
# Hook has already extracted model and unwrapped batch_id into data dict
response = await llm_router.acancel_batch(**data) # type: ignore
response._hidden_params["unified_batch_id"] = unified_batch_id

# Ensure model_id is set for the post_call_success_hook to re-encode IDs
if not response._hidden_params.get("model_id") and data.get("model"):
response._hidden_params["model_id"] = data["model"]

# SCENARIO 3: Fallback to custom_llm_provider (uses env variables)
else:
Expand All @@ -775,6 +766,11 @@ async def cancel_batch(
**_cancel_batch_data,
)

### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)

### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
Expand Down
4 changes: 4 additions & 0 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,10 @@ async def common_processing_pre_call_logic(
"acreate_batch",
"aretrieve_batch",
"alist_batches",
"acancel_batch",
"afile_content",
"afile_retrieve",
"afile_delete",
"atext_completion",
"acreate_fine_tuning_job",
"acancel_fine_tuning_job",
Expand Down Expand Up @@ -606,6 +608,8 @@ async def base_process_llm_request(
"aget_interaction",
"adelete_interaction",
"acancel_interaction",
"acancel_batch",
"afile_delete",
],
proxy_logging_obj: ProxyLogging,
general_settings: dict,
Expand Down
24 changes: 21 additions & 3 deletions litellm/proxy/openai_files_endpoints/files_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ async def get_file(
version,
)

data: Dict = {}
data: Dict = {"file_id": file_id}
try:

custom_llm_provider = (
Expand Down Expand Up @@ -992,7 +992,7 @@ async def delete_file(
version,
)

data: Dict = {}
data: Dict = {"file_id": file_id}
try:
custom_llm_provider = (
provider
Expand All @@ -1001,6 +1001,22 @@ async def delete_file(
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)

# Call common_processing_pre_call_logic to trigger permission checks
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
(
data,
litellm_logging_obj,
) = await base_llm_response_processor.common_processing_pre_call_logic(
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_logging_obj=proxy_logging_obj,
proxy_config=proxy_config,
route_type="afile_delete",
)

# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
Expand Down Expand Up @@ -1060,11 +1076,13 @@ async def delete_file(
code=500,
)

# Remove file_id from data to avoid duplicate keyword argument
data_without_file_id = {k: v for k, v in data.items() if k != "file_id"}
response = await managed_files_obj.afile_delete(
file_id=file_id,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
llm_router=llm_router,
**data,
**data_without_file_id,
)
else:
response = await litellm.afile_delete(
Expand Down
6 changes: 3 additions & 3 deletions litellm/types/llms/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ class AnthropicMessagesDocumentParam(TypedDict, total=False):
citations: Optional[CitationsObject]


class AnthropicMessagesToolResultContent(TypedDict):
type: Literal["text"]
text: str
class AnthropicMessagesToolResultContent(TypedDict, total=False):
type: Required[Literal["text"]]
text: Required[str]
cache_control: Optional[Union[dict, ChatCompletionCachedContent]]


Expand Down
2 changes: 2 additions & 0 deletions litellm/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ class CallTypes(str, Enum):
acreate_batch = "acreate_batch"
aretrieve_batch = "aretrieve_batch"
retrieve_batch = "retrieve_batch"
acancel_batch = "acancel_batch"
cancel_batch = "cancel_batch"
pass_through = "pass_through_endpoint"
anthropic_messages = "anthropic_messages"
get_assistants = "get_assistants"
Expand Down
Loading
Loading