Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from litellm._uuid import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional

from litellm._logging import verbose_proxy_logger

Expand Down Expand Up @@ -35,14 +35,11 @@ async def check_batch_cost(self):
- if not, return False
- if so, return True
"""
from litellm_enterprise.proxy.hooks.managed_files import (
_PROXY_LiteLLMManagedFiles,
)

from litellm.batches.batch_utils import (
_get_file_content_as_dictionary,
calculate_batch_cost_and_usage,
)
from litellm.files.main import afile_content
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.proxy.openai_files_endpoints.common_utils import (
Expand Down Expand Up @@ -102,27 +99,29 @@ async def check_batch_cost(self):
continue

## RETRIEVE THE BATCH JOB OUTPUT FILE
managed_files_obj = cast(
Optional[_PROXY_LiteLLMManagedFiles],
self.proxy_logging_obj.get_proxy_hook("managed_files"),
)
if (
response.status == "completed"
and response.output_file_id is not None
and managed_files_obj is not None
):
verbose_proxy_logger.info(
f"Batch ID: {batch_id} is complete, tracking cost and usage"
)
# track cost
model_file_id_mapping = {
response.output_file_id: {model_id: response.output_file_id}
}
_file_content = await managed_files_obj.afile_content(
file_id=response.output_file_id,
litellm_parent_otel_span=None,
llm_router=self.llm_router,
model_file_id_mapping=model_file_id_mapping,

# This background job runs as default_user_id, so going through the HTTP endpoint
# would trigger check_managed_file_id_access and get 403. Instead, extract the raw
# provider file ID and call afile_content directly with deployment credentials.
raw_output_file_id = response.output_file_id
decoded = _is_base64_encoded_unified_file_id(raw_output_file_id)
if decoded:
try:
raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0]
except (IndexError, AttributeError):
pass

credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {}
_file_content = await afile_content(
file_id=raw_output_file_id,
**credentials,
)

file_content_as_dict = _get_file_content_as_dictionary(
Expand All @@ -143,11 +142,15 @@ async def check_batch_cost(self):
custom_llm_provider=custom_llm_provider,
)

# Pass deployment model_info so custom batch pricing
# (input_cost_per_token_batches etc.) is used for cost calc
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
batch_cost, batch_usage, batch_models = (
await calculate_batch_cost_and_usage(
file_content_dictionary=file_content_as_dict,
custom_llm_provider=llm_provider, # type: ignore
model_name=model_name,
model_info=deployment_model_info,
)
)
logging_obj = LiteLLMLogging(
Expand Down
35 changes: 25 additions & 10 deletions enterprise/litellm_enterprise/proxy/hooks/managed_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,14 @@ async def can_user_call_unified_file_id(

if managed_file:
return managed_file.created_by == user_id
return False
raise HTTPException(
status_code=404,
detail=f"File not found: {unified_file_id}",
)

async def can_user_call_unified_object_id(
self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
## check if the user has access to the unified object id
## check if the user has access to the unified object id
user_id = user_api_key_dict.user_id
managed_object = (
Expand All @@ -246,7 +248,10 @@ async def can_user_call_unified_object_id(

if managed_object:
return managed_object.created_by == user_id
return True # don't raise error if managed object is not found
raise HTTPException(
status_code=404,
detail=f"Object not found: {unified_object_id}",
)

async def list_user_batches(
self,
Expand Down Expand Up @@ -911,15 +916,22 @@ async def async_post_call_success_hook(
)
setattr(response, file_attr, unified_file_id)

# Fetch the actual file object from the provider
# Use llm_router credentials when available. Without credentials,
# Azure and other auth-required providers return 500/401.
file_object = None
try:
# Use litellm to retrieve the file object from the provider
from litellm import afile_retrieve
file_object = await afile_retrieve(
custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai",
file_id=original_file_id
)
from litellm.proxy.proxy_server import llm_router as _llm_router
if _llm_router is not None and model_id:
_creds = _llm_router.get_deployment_credentials_with_provider(model_id) or {}
file_object = await litellm.afile_retrieve(
file_id=original_file_id,
**_creds,
)
else:
file_object = await litellm.afile_retrieve(
custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai",
file_id=original_file_id,
)
verbose_logger.debug(
f"Successfully retrieved file object for {file_attr}={original_file_id}"
)
Expand Down Expand Up @@ -1004,7 +1016,10 @@ async def afile_retrieve(
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

# Case 2: Managed file and the file object exists in the database
# The stored file_object has the raw provider ID. Replace with the unified ID
# so callers see a consistent ID (matching Case 3 which does response.id = file_id).
if stored_file_object and stored_file_object.file_object:
stored_file_object.file_object.id = file_id
return stored_file_object.file_object

# Case 3: Managed file exists in the database but not the file object (e.g., the batch task might not have run)
Expand Down
4 changes: 3 additions & 1 deletion litellm/_service_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,12 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
_duration, type(_duration)
)
) # invalid _duration value
# Batch polling callbacks (check_batch_cost) don't include call_type in kwargs.
# Use .get() to avoid KeyError.
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs["call_type"],
call_type=kwargs.get("call_type", "unknown")
)
except Exception as e:
raise e
36 changes: 30 additions & 6 deletions litellm/batches/batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@ async def calculate_batch_cost_and_usage(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"],
model_name: Optional[str] = None,
model_info: Optional[dict] = None,
) -> Tuple[float, Usage, List[str]]:
"""
Calculate the cost and usage of a batch
Calculate the cost and usage of a batch.

Args:
model_info: Optional deployment-level model info with custom batch
pricing. Threaded through to batch_cost_calculator so that
deployment-specific pricing (e.g. input_cost_per_token_batches)
is used instead of the global cost map.
"""
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
model_info=model_info,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
Expand Down Expand Up @@ -94,6 +102,7 @@ def _batch_cost_calculator(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
model_name: Optional[str] = None,
model_info: Optional[dict] = None,
) -> float:
"""
Calculate the cost of a batch based on the output file id
Expand All @@ -108,6 +117,7 @@ def _batch_cost_calculator(
total_cost = _get_batch_job_cost_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_info=model_info,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
Expand Down Expand Up @@ -290,10 +300,13 @@ def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
def _get_batch_job_cost_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai",
model_info: Optional[dict] = None,
) -> float:
"""
Get the cost of a batch job from the file content
"""
from litellm.cost_calculator import batch_cost_calculator

try:
total_cost: float = 0.0
# parse the file content as json
Expand All @@ -303,11 +316,22 @@ def _get_batch_job_cost_from_file_content(
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
total_cost += litellm.completion_cost(
completion_response=_response_body,
custom_llm_provider=custom_llm_provider,
call_type=CallTypes.aretrieve_batch.value,
)
if model_info is not None:
usage = _get_batch_job_usage_from_response_body(_response_body)
model = _response_body.get("model", "")
prompt_cost, completion_cost = batch_cost_calculator(
usage=usage,
model=model,
custom_llm_provider=custom_llm_provider,
model_info=model_info,
)
total_cost += prompt_cost + completion_cost
else:
total_cost += litellm.completion_cost(
completion_response=_response_body,
custom_llm_provider=custom_llm_provider,
call_type=CallTypes.aretrieve_batch.value,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
except Exception as e:
Expand Down
22 changes: 15 additions & 7 deletions litellm/cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,9 +1892,16 @@ def batch_cost_calculator(
usage: Usage,
model: str,
custom_llm_provider: Optional[str] = None,
model_info: Optional[dict] = None,
) -> Tuple[float, float]:
"""
Calculate the cost of a batch job
Calculate the cost of a batch job.

Args:
model_info: Optional deployment-level model info containing custom
batch pricing (e.g. input_cost_per_token_batches). When provided,
skips the global litellm.get_model_info() lookup so that
deployment-specific pricing is used.
"""

_, custom_llm_provider, _, _ = litellm.get_llm_provider(
Expand All @@ -1907,12 +1914,13 @@ def batch_cost_calculator(
custom_llm_provider,
)

try:
model_info: Optional[ModelInfo] = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception:
model_info = None
if model_info is None:
try:
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception:
model_info = None

if not model_info:
return 0.0, 0.0
Expand Down
8 changes: 7 additions & 1 deletion litellm/integrations/s3_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,14 @@ async def _async_log_event_base(self, kwargs, response_obj, start_time, end_time
standard_logging_payload=kwargs.get("standard_logging_object", None),
)

# afile_delete and other non-model call types never produce a standard_logging_object,
# so s3_batch_logging_element is None. Skip gracefully instead of raising ValueError.
if s3_batch_logging_element is None:
raise ValueError("s3_batch_logging_element is None")
verbose_logger.debug(
"s3 Logging - skipping event, no standard_logging_object for call_type=%s",
kwargs.get("call_type", "unknown"),
)
return

verbose_logger.debug(
"\ns3 Logger - Logging payload = %s", s3_batch_logging_element
Expand Down
11 changes: 11 additions & 0 deletions litellm/proxy/batches_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_models_from_unified_file_id,
get_original_file_id,
prepare_data_with_credentials,
resolve_input_file_id_to_unified,
update_batch_in_database,
)
from litellm.proxy.utils import handle_exception_on_proxy, is_known_model
Expand Down Expand Up @@ -377,6 +378,11 @@ async def retrieve_batch(
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)

# async_post_call_success_hook replaces batch.id and output_file_id with unified IDs
# but not input_file_id. Resolve raw provider ID to unified ID.
if unified_batch_id:
await resolve_input_file_id_to_unified(response, prisma_client)

asyncio.create_task(
proxy_logging_obj.update_request_status(
Expand Down Expand Up @@ -479,6 +485,11 @@ async def retrieve_batch(
data=data, user_api_key_dict=user_api_key_dict, response=response
)

# Fix (bug_feb14_batch_retrieve_returns_raw_input_file_id): the post-call success hook
# does not rewrite input_file_id, so resolve the raw provider input_file_id to its unified ID here.
if unified_batch_id:
await resolve_input_file_id_to_unified(response, prisma_client)

### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
Expand Down
3 changes: 2 additions & 1 deletion litellm/proxy/hooks/batch_rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,10 @@ async def count_input_file_usage(
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
)
# Managed files require bypassing the HTTP endpoint (which runs access-check hooks)
# and calling the managed files hook directly with the user's credentials.
is_managed_file = _is_base64_encoded_unified_file_id(file_id)
if is_managed_file and user_api_key_dict is not None:
# For managed files, use the managed files hook directly
file_content = await self._fetch_managed_file_content(
file_id=file_id,
user_api_key_dict=user_api_key_dict,
Expand Down
8 changes: 8 additions & 0 deletions litellm/proxy/hooks/proxy_track_cost_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ async def _PROXY_track_cost_callback(
max_budget=end_user_max_budget,
)
else:
# Non-model call types (health checks, afile_delete) have no model or standard_logging_object.
# Use .get() for "stream" to avoid KeyError on health checks.
if sl_object is None and not kwargs.get("model"):
verbose_proxy_logger.warning(
"Cost tracking - skipping, no standard_logging_object and no model for call_type=%s",
kwargs.get("call_type", "unknown"),
)
return
if kwargs.get("stream") is not True or (
kwargs.get("stream") is True and "complete_streaming_response" in kwargs
):
Expand Down
Loading
Loading