diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py index bb25e4f0626..b28b4497e7c 100644 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py +++ b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py @@ -4,7 +4,7 @@ from litellm._uuid import uuid from datetime import datetime -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional from litellm._logging import verbose_proxy_logger @@ -35,14 +35,11 @@ async def check_batch_cost(self): - if not, return False - if so, return True """ - from litellm_enterprise.proxy.hooks.managed_files import ( - _PROXY_LiteLLMManagedFiles, - ) - from litellm.batches.batch_utils import ( _get_file_content_as_dictionary, calculate_batch_cost_and_usage, ) + from litellm.files.main import afile_content from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.proxy.openai_files_endpoints.common_utils import ( @@ -102,27 +99,29 @@ async def check_batch_cost(self): continue ## RETRIEVE THE BATCH JOB OUTPUT FILE - managed_files_obj = cast( - Optional[_PROXY_LiteLLMManagedFiles], - self.proxy_logging_obj.get_proxy_hook("managed_files"), - ) if ( response.status == "completed" and response.output_file_id is not None - and managed_files_obj is not None ): verbose_proxy_logger.info( f"Batch ID: {batch_id} is complete, tracking cost and usage" ) - # track cost - model_file_id_mapping = { - response.output_file_id: {model_id: response.output_file_id} - } - _file_content = await managed_files_obj.afile_content( - file_id=response.output_file_id, - litellm_parent_otel_span=None, - llm_router=self.llm_router, - model_file_id_mapping=model_file_id_mapping, + + # This background job runs as default_user_id, so going through the HTTP endpoint + # would trigger 
check_managed_file_id_access and get 403. Instead, extract the raw + # provider file ID and call afile_content directly with deployment credentials. + raw_output_file_id = response.output_file_id + decoded = _is_base64_encoded_unified_file_id(raw_output_file_id) + if decoded: + try: + raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0] + except (IndexError, AttributeError): + pass + + credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {} + _file_content = await afile_content( + file_id=raw_output_file_id, + **credentials, ) file_content_as_dict = _get_file_content_as_dictionary( @@ -143,11 +142,15 @@ async def check_batch_cost(self): custom_llm_provider=custom_llm_provider, ) + # Pass deployment model_info so custom batch pricing + # (input_cost_per_token_batches etc.) is used for cost calc + deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {} batch_cost, batch_usage, batch_models = ( await calculate_batch_cost_and_usage( file_content_dictionary=file_content_as_dict, custom_llm_provider=llm_provider, # type: ignore model_name=model_name, + model_info=deployment_model_info, ) ) logging_obj = LiteLLMLogging( diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index a41b3f3bf6f..f341a1e9634 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -230,12 +230,14 @@ async def can_user_call_unified_file_id( if managed_file: return managed_file.created_by == user_id - return False + raise HTTPException( + status_code=404, + detail=f"File not found: {unified_file_id}", + ) async def can_user_call_unified_object_id( self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth ) -> bool: - ## check if the user has access to the unified object id ## check if the user has access to the unified object id 
user_id = user_api_key_dict.user_id managed_object = ( @@ -246,7 +248,10 @@ async def can_user_call_unified_object_id( if managed_object: return managed_object.created_by == user_id - return True # don't raise error if managed object is not found + raise HTTPException( + status_code=404, + detail=f"Object not found: {unified_object_id}", + ) async def list_user_batches( self, @@ -911,15 +916,22 @@ async def async_post_call_success_hook( ) setattr(response, file_attr, unified_file_id) - # Fetch the actual file object from the provider + # Use llm_router credentials when available. Without credentials, + # Azure and other auth-required providers return 500/401. file_object = None try: - # Use litellm to retrieve the file object from the provider - from litellm import afile_retrieve - file_object = await afile_retrieve( - custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", - file_id=original_file_id - ) + from litellm.proxy.proxy_server import llm_router as _llm_router + if _llm_router is not None and model_id: + _creds = _llm_router.get_deployment_credentials_with_provider(model_id) or {} + file_object = await litellm.afile_retrieve( + file_id=original_file_id, + **_creds, + ) + else: + file_object = await litellm.afile_retrieve( + custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", + file_id=original_file_id, + ) verbose_logger.debug( f"Successfully retrieved file object for {file_attr}={original_file_id}" ) @@ -1004,7 +1016,10 @@ async def afile_retrieve( raise Exception(f"LiteLLM Managed File object with id={file_id} not found") # Case 2: Managed file and the file object exists in the database + # The stored file_object has the raw provider ID. Replace with the unified ID + # so callers see a consistent ID (matching Case 3 which does response.id = file_id). 
if stored_file_object and stored_file_object.file_object: + stored_file_object.file_object.id = file_id return stored_file_object.file_object # Case 3: Managed file exists in the database but not the file object (for. e.g the batch task might not have run) diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index b67d0d86063..8f9a3c5083f 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -312,10 +312,12 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti _duration, type(_duration) ) ) # invalid _duration value + # Batch polling callbacks (check_batch_cost) don't include call_type in kwargs. + # Use .get() to avoid KeyError. await self.async_service_success_hook( service=ServiceTypes.LITELLM, duration=_duration, - call_type=kwargs["call_type"], + call_type=kwargs.get("call_type", "unknown") ) except Exception as e: raise e diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py index 16a467e00cb..c92ab9b230e 100644 --- a/litellm/batches/batch_utils.py +++ b/litellm/batches/batch_utils.py @@ -16,14 +16,22 @@ async def calculate_batch_cost_and_usage( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"], model_name: Optional[str] = None, + model_info: Optional[dict] = None, ) -> Tuple[float, Usage, List[str]]: """ - Calculate the cost and usage of a batch + Calculate the cost and usage of a batch. + + Args: + model_info: Optional deployment-level model info with custom batch + pricing. Threaded through to batch_cost_calculator so that + deployment-specific pricing (e.g. input_cost_per_token_batches) + is used instead of the global cost map. 
""" batch_cost = _batch_cost_calculator( custom_llm_provider=custom_llm_provider, file_content_dictionary=file_content_dictionary, model_name=model_name, + model_info=model_info, ) batch_usage = _get_batch_job_total_usage_from_file_content( file_content_dictionary=file_content_dictionary, @@ -94,6 +102,7 @@ def _batch_cost_calculator( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai", model_name: Optional[str] = None, + model_info: Optional[dict] = None, ) -> float: """ Calculate the cost of a batch based on the output file id @@ -108,6 +117,7 @@ def _batch_cost_calculator( total_cost = _get_batch_job_cost_from_file_content( file_content_dictionary=file_content_dictionary, custom_llm_provider=custom_llm_provider, + model_info=model_info, ) verbose_logger.debug("total_cost=%s", total_cost) return total_cost @@ -290,10 +300,13 @@ def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]: def _get_batch_job_cost_from_file_content( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai", + model_info: Optional[dict] = None, ) -> float: """ Get the cost of a batch job from the file content """ + from litellm.cost_calculator import batch_cost_calculator + try: total_cost: float = 0.0 # parse the file content as json @@ -303,11 +316,22 @@ def _get_batch_job_cost_from_file_content( for _item in file_content_dictionary: if _batch_response_was_successful(_item): _response_body = _get_response_from_batch_job_output_file(_item) - total_cost += litellm.completion_cost( - completion_response=_response_body, - custom_llm_provider=custom_llm_provider, - call_type=CallTypes.aretrieve_batch.value, - ) + if model_info is not None: + usage = _get_batch_job_usage_from_response_body(_response_body) + model = _response_body.get("model", "") + prompt_cost, completion_cost = batch_cost_calculator( + usage=usage, 
+ model=model, + custom_llm_provider=custom_llm_provider, + model_info=model_info, + ) + total_cost += prompt_cost + completion_cost + else: + total_cost += litellm.completion_cost( + completion_response=_response_body, + custom_llm_provider=custom_llm_provider, + call_type=CallTypes.aretrieve_batch.value, + ) verbose_logger.debug("total_cost=%s", total_cost) return total_cost except Exception as e: diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 4ea22dbd90f..48dfec2e8c2 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1892,9 +1892,16 @@ def batch_cost_calculator( usage: Usage, model: str, custom_llm_provider: Optional[str] = None, + model_info: Optional[dict] = None, ) -> Tuple[float, float]: """ - Calculate the cost of a batch job + Calculate the cost of a batch job. + + Args: + model_info: Optional deployment-level model info containing custom + batch pricing (e.g. input_cost_per_token_batches). When provided, + skips the global litellm.get_model_info() lookup so that + deployment-specific pricing is used. 
""" _, custom_llm_provider, _, _ = litellm.get_llm_provider( @@ -1907,12 +1914,13 @@ def batch_cost_calculator( custom_llm_provider, ) - try: - model_info: Optional[ModelInfo] = litellm.get_model_info( - model=model, custom_llm_provider=custom_llm_provider - ) - except Exception: - model_info = None + if model_info is None: + try: + model_info = litellm.get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) + except Exception: + model_info = None if not model_info: return 0.0, 0.0 diff --git a/litellm/integrations/s3_v2.py b/litellm/integrations/s3_v2.py index 534b85e4752..e0932fc3373 100644 --- a/litellm/integrations/s3_v2.py +++ b/litellm/integrations/s3_v2.py @@ -247,8 +247,14 @@ async def _async_log_event_base(self, kwargs, response_obj, start_time, end_time standard_logging_payload=kwargs.get("standard_logging_object", None), ) + # afile_delete and other non-model call types never produce a standard_logging_object, + # so s3_batch_logging_element is None. Skip gracefully instead of raising ValueError. 
if s3_batch_logging_element is None: - raise ValueError("s3_batch_logging_element is None") + verbose_logger.debug( + "s3 Logging - skipping event, no standard_logging_object for call_type=%s", + kwargs.get("call_type", "unknown"), + ) + return verbose_logger.debug( "\ns3 Logger - Logging payload = %s", s3_batch_logging_element diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 06800cb4524..783ac9d6f19 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -29,6 +29,7 @@ get_models_from_unified_file_id, get_original_file_id, prepare_data_with_credentials, + resolve_input_file_id_to_unified, update_batch_in_database, ) from litellm.proxy.utils import handle_exception_on_proxy, is_known_model @@ -377,6 +378,11 @@ async def retrieve_batch( response = await proxy_logging_obj.post_call_success_hook( data=data, user_api_key_dict=user_api_key_dict, response=response ) + + # async_post_call_success_hook replaces batch.id and output_file_id with unified IDs + # but not input_file_id. Resolve raw provider ID to unified ID. + if unified_batch_id: + await resolve_input_file_id_to_unified(response, prisma_client) asyncio.create_task( proxy_logging_obj.update_request_status( @@ -479,6 +485,11 @@ async def retrieve_batch( data=data, user_api_key_dict=user_api_key_dict, response=response ) + # post_call_success_hook leaves input_file_id as the raw provider ID here too. + # Resolve raw provider input_file_id to unified ID. 
+ if unified_batch_id: + await resolve_input_file_id_to_unified(response, prisma_client) + ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( diff --git a/litellm/proxy/hooks/batch_rate_limiter.py b/litellm/proxy/hooks/batch_rate_limiter.py index 45b1bd8653f..5bebcc92072 100644 --- a/litellm/proxy/hooks/batch_rate_limiter.py +++ b/litellm/proxy/hooks/batch_rate_limiter.py @@ -259,9 +259,10 @@ async def count_input_file_usage( from litellm.proxy.openai_files_endpoints.common_utils import ( _is_base64_encoded_unified_file_id, ) + # Managed files require bypassing the HTTP endpoint (which runs access-check hooks) + # and calling the managed files hook directly with the user's credentials. is_managed_file = _is_base64_encoded_unified_file_id(file_id) if is_managed_file and user_api_key_dict is not None: - # For managed files, use the managed files hook directly file_content = await self._fetch_managed_file_content( file_id=file_id, user_api_key_dict=user_api_key_dict, diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 37b79e6d065..d903ce0d9d7 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -202,6 +202,14 @@ async def _PROXY_track_cost_callback( max_budget=end_user_max_budget, ) else: + # Non-model call types (health checks, afile_delete) have no model or standard_logging_object. + # Use .get() for "stream" to avoid KeyError on health checks. 
+ if sl_object is None and not kwargs.get("model"): + verbose_proxy_logger.warning( + "Cost tracking - skipping, no standard_logging_object and no model for call_type=%s", + kwargs.get("call_type", "unknown"), + ) + return if kwargs.get("stream") is not True or ( kwargs.get("stream") is True and "complete_streaming_response" in kwargs ): diff --git a/litellm/proxy/openai_files_endpoints/common_utils.py b/litellm/proxy/openai_files_endpoints/common_utils.py index f67dc5e2aaa..75f64cddf59 100644 --- a/litellm/proxy/openai_files_endpoints/common_utils.py +++ b/litellm/proxy/openai_files_endpoints/common_utils.py @@ -644,6 +644,28 @@ def _extract_model_param(request: "Request", request_body: dict) -> Optional[str # ============================================================================ +async def resolve_input_file_id_to_unified(response, prisma_client) -> None: + """ + If the batch response contains a raw provider input_file_id (not already a + unified ID), look up the corresponding unified file ID from the managed file + table and replace it in-place. + """ + if ( + hasattr(response, "input_file_id") + and response.input_file_id + and not _is_base64_encoded_unified_file_id(response.input_file_id) + and prisma_client + ): + try: + managed_file = await prisma_client.db.litellm_managedfiletable.find_first( + where={"flat_model_file_ids": {"has": response.input_file_id}} + ) + if managed_file: + response.input_file_id = managed_file.unified_file_id + except Exception: + pass + + async def get_batch_from_database( batch_id: str, unified_batch_id: Union[str, Literal[False]], @@ -687,6 +709,9 @@ async def get_batch_from_database( batch_data = json.loads(db_batch_object.file_object) if isinstance(db_batch_object.file_object, str) else db_batch_object.file_object response = LiteLLMBatch(**batch_data) response.id = batch_id + + # The stored batch object has the raw provider input_file_id. Resolve to unified ID. 
+ await resolve_input_file_id_to_unified(response, prisma_client) verbose_proxy_logger.debug( f"Retrieved batch {batch_id} from ManagedObjectTable with status={response.status}" diff --git a/tests/batches_tests/test_batch_custom_pricing.py b/tests/batches_tests/test_batch_custom_pricing.py new file mode 100644 index 00000000000..8bc1bd5a307 --- /dev/null +++ b/tests/batches_tests/test_batch_custom_pricing.py @@ -0,0 +1,131 @@ +""" +Test that batch cost calculation uses custom deployment-level pricing +when model_info is provided. + +Reproduces the bug where `input_cost_per_token_batches` / +`output_cost_per_token_batches` set on a proxy deployment's model_info +are ignored by the batch cost pipeline because they are never threaded +through to `batch_cost_calculator`. +""" + +import pytest + +from litellm.batches.batch_utils import ( + _batch_cost_calculator, + _get_batch_job_cost_from_file_content, + calculate_batch_cost_and_usage, +) +from litellm.cost_calculator import batch_cost_calculator +from litellm.types.utils import Usage + + +# --- helpers --- + +def _make_batch_output_line(prompt_tokens: int = 10, completion_tokens: int = 5): + """Return a single successful batch output line (OpenAI JSONL format).""" + return { + "id": "batch_req_1", + "custom_id": "req-1", + "response": { + "status_code": 200, + "body": { + "id": "chatcmpl-test", + "object": "chat.completion", + "model": "fake-batch-model", + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello"}, + "finish_reason": "stop", + } + ], + }, + }, + "error": None, + } + + +CUSTOM_MODEL_INFO = { + "input_cost_per_token_batches": 0.00125, + "output_cost_per_token_batches": 0.005, +} + + +# --- tests --- + + +def test_batch_cost_calculator_uses_custom_model_info(): + """batch_cost_calculator should use model_info override when 
provided.""" + usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + prompt_cost, completion_cost = batch_cost_calculator( + usage=usage, + model="fake-batch-model", + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected_prompt = 10 * 0.00125 + expected_completion = 5 * 0.005 + assert prompt_cost == pytest.approx(expected_prompt), ( + f"Expected prompt cost {expected_prompt}, got {prompt_cost}" + ) + assert completion_cost == pytest.approx(expected_completion), ( + f"Expected completion cost {expected_completion}, got {completion_cost}" + ) + + +def test_get_batch_job_cost_from_file_content_uses_custom_model_info(): + """_get_batch_job_cost_from_file_content should thread model_info to batch_cost_calculator.""" + file_content = [_make_batch_output_line(prompt_tokens=10, completion_tokens=5)] + + cost = _get_batch_job_cost_from_file_content( + file_content_dictionary=file_content, + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected = (10 * 0.00125) + (5 * 0.005) + assert cost == pytest.approx(expected), ( + f"Expected total cost {expected}, got {cost}" + ) + + +def test_batch_cost_calculator_func_uses_custom_model_info(): + """_batch_cost_calculator should thread model_info.""" + file_content = [_make_batch_output_line(prompt_tokens=10, completion_tokens=5)] + + cost = _batch_cost_calculator( + file_content_dictionary=file_content, + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected = (10 * 0.00125) + (5 * 0.005) + assert cost == pytest.approx(expected), ( + f"Expected total cost {expected}, got {cost}" + ) + + +@pytest.mark.asyncio +async def test_calculate_batch_cost_and_usage_uses_custom_model_info(): + """calculate_batch_cost_and_usage should thread model_info.""" + file_content = [_make_batch_output_line(prompt_tokens=10, completion_tokens=5)] + + batch_cost, batch_usage, batch_models = await calculate_batch_cost_and_usage( + file_content_dictionary=file_content, + 
custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected = (10 * 0.00125) + (5 * 0.005) + assert batch_cost == pytest.approx(expected), ( + f"Expected total cost {expected}, got {batch_cost}" + ) + assert batch_usage.prompt_tokens == 10 + assert batch_usage.completion_tokens == 5 diff --git a/tests/test_litellm/enterprise/proxy/test_afile_retrieve_returns_unified_id.py b/tests/test_litellm/enterprise/proxy/test_afile_retrieve_returns_unified_id.py new file mode 100644 index 00000000000..7040aef73e5 --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_afile_retrieve_returns_unified_id.py @@ -0,0 +1,67 @@ +""" +Test that managed_files.afile_retrieve returns the unified file ID, not the +raw provider file ID, when file_object is already stored in the database. + +Bug: managed_files.py Case 2 returns stored_file_object.file_object directly +without replacing .id with the unified ID. Case 3 (fetch from provider) does +it correctly by setting response.id = file_id. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from litellm.proxy._types import LiteLLM_ManagedFileTable +from litellm.types.llms.openai import OpenAIFileObject + + +def _make_managed_files_instance(): + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + instance = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=MagicMock(), + ) + return instance + + +@pytest.mark.asyncio +async def test_should_return_unified_id_when_file_object_exists_in_db(): + """ + When get_unified_file_id returns a stored file_object (Case 2), + afile_retrieve must set .id to the unified file ID before returning. 
+ """ + unified_id = "bGl0ZWxsbV9wcm94eTp1bmlmaWVkX291dHB1dF9maWxl" + raw_provider_id = "batch_20260214-output-file-1" + + stored = LiteLLM_ManagedFileTable( + unified_file_id=unified_id, + file_object=OpenAIFileObject( + id=raw_provider_id, + bytes=489, + created_at=1700000000, + filename="batch_output.jsonl", + object="file", + purpose="batch_output", + status="processed", + ), + model_mappings={"model-abc": raw_provider_id}, + flat_model_file_ids=[raw_provider_id], + created_by="test-user", + updated_by="test-user", + ) + + managed_files = _make_managed_files_instance() + managed_files.get_unified_file_id = AsyncMock(return_value=stored) + + result = await managed_files.afile_retrieve( + file_id=unified_id, + litellm_parent_otel_span=None, + llm_router=None, + ) + + assert result.id == unified_id, ( + f"afile_retrieve should return the unified ID '{unified_id}', " + f"but got raw provider ID '{result.id}'" + ) diff --git a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py new file mode 100644 index 00000000000..6e9c3c0354b --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py @@ -0,0 +1,75 @@ +""" +Test that batch retrieve endpoint resolves raw input_file_id to the +unified managed file ID before returning. + +Bug: After batch completion, batches.retrieve returns the raw provider +input_file_id instead of the LiteLLM unified ID. 
+""" + +import base64 +import json + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from litellm.proxy.openai_files_endpoints.common_utils import ( + _is_base64_encoded_unified_file_id, +) + + +DECODED_UNIFIED_INPUT_FILE_ID = "litellm_proxy:application/octet-stream;unified_id,test-uuid;target_model_names,azure-gpt-4" +B64_UNIFIED_INPUT_FILE_ID = base64.urlsafe_b64encode(DECODED_UNIFIED_INPUT_FILE_ID.encode()).decode().rstrip("=") +RAW_INPUT_FILE_ID = "file-raw-provider-abc123" + +DECODED_UNIFIED_BATCH_ID = "litellm_proxy;model_id:model-xyz;llm_batch_id:batch-123" +B64_UNIFIED_BATCH_ID = base64.urlsafe_b64encode(DECODED_UNIFIED_BATCH_ID.encode()).decode().rstrip("=") + + +@pytest.mark.asyncio +async def test_should_resolve_raw_input_file_id_to_unified(): + """ + When a completed batch has a raw input_file_id and the managed file table + contains a record for that raw ID, the retrieve endpoint should resolve + it to the unified file ID. + """ + unified_batch_id = _is_base64_encoded_unified_file_id(B64_UNIFIED_BATCH_ID) + assert unified_batch_id, "Test setup: batch_id should decode as unified" + + from litellm.types.utils import LiteLLMBatch + + batch_data = { + "id": B64_UNIFIED_BATCH_ID, + "completion_window": "24h", + "created_at": 1700000000, + "endpoint": "/v1/chat/completions", + "input_file_id": RAW_INPUT_FILE_ID, + "object": "batch", + "status": "completed", + "output_file_id": "file-output-xyz", + } + + mock_db_object = MagicMock() + mock_db_object.file_object = json.dumps(batch_data) + + mock_managed_file = MagicMock() + mock_managed_file.unified_file_id = B64_UNIFIED_INPUT_FILE_ID + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedobjecttable.find_first = AsyncMock(return_value=mock_db_object) + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock(return_value=mock_managed_file) + + from litellm.proxy.openai_files_endpoints.common_utils import get_batch_from_database + + _, response = await 
get_batch_from_database( + batch_id=B64_UNIFIED_BATCH_ID, + unified_batch_id=unified_batch_id, + managed_files_obj=MagicMock(), + prisma_client=mock_prisma, + verbose_proxy_logger=MagicMock(), + ) + + assert response is not None, "Batch should be found in DB" + assert response.input_file_id == B64_UNIFIED_INPUT_FILE_ID, ( + f"input_file_id should be unified '{B64_UNIFIED_INPUT_FILE_ID}', " + f"got raw '{response.input_file_id}'" + ) diff --git a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py new file mode 100644 index 00000000000..420f5f9789c --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py @@ -0,0 +1,124 @@ +""" +Test that get_batch_from_database resolves raw input_file_id to the +unified/managed file ID when reading a batch from the database. + +Bug: The batch retrieve path stores the raw provider input_file_id in the +DB (via async_post_call_success_hook on the retrieve endpoint). When the +batch is later read from DB, get_batch_from_database returns the raw ID +without resolving it to the unified ID. 
+""" + +import json +import pytest +from typing import Optional +from unittest.mock import AsyncMock, MagicMock + +from litellm.proxy.openai_files_endpoints.common_utils import get_batch_from_database + + +def _mock_prisma(batch_json: str, managed_file_record=None): + """Create a mock prisma client with canned responses.""" + prisma = MagicMock() + + batch_db_record = MagicMock() + batch_db_record.file_object = batch_json + + prisma.db.litellm_managedobjecttable.find_first = AsyncMock( + return_value=batch_db_record + ) + + prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=managed_file_record + ) + + return prisma + + +@pytest.mark.asyncio +async def test_should_resolve_raw_input_file_id_to_unified_id(): + """ + When input_file_id in the stored batch is a raw provider ID, + get_batch_from_database must look up the unified ID from the + managed files table. + """ + unified_batch_id = "bGl0ZWxsbV9wcm94eTpiYXRjaF9pZA" + unified_input_file_id = "bGl0ZWxsbV9wcm94eTp1bmlmaWVkX2lucHV0" + raw_input_file_id = "file-abc123-raw" + + batch_data = { + "id": "batch-raw-123", + "completion_window": "24h", + "created_at": 1700000000, + "endpoint": "/v1/chat/completions", + "input_file_id": raw_input_file_id, + "object": "batch", + "status": "completed", + "output_file_id": "file-output-raw", + } + + managed_file_record = MagicMock() + managed_file_record.unified_file_id = unified_input_file_id + + prisma = _mock_prisma( + batch_json=json.dumps(batch_data), + managed_file_record=managed_file_record, + ) + + _, response = await get_batch_from_database( + batch_id=unified_batch_id, + unified_batch_id="decoded_unified_batch_id", + managed_files_obj=MagicMock(), + prisma_client=prisma, + verbose_proxy_logger=MagicMock(), + ) + + assert response is not None + assert response.input_file_id == unified_input_file_id, ( + f"input_file_id should be resolved to '{unified_input_file_id}', " + f"got raw: '{response.input_file_id}'" + ) + + 
prisma.db.litellm_managedfiletable.find_first.assert_called_once_with( + where={"flat_model_file_ids": {"has": raw_input_file_id}} + ) + + +@pytest.mark.asyncio +async def test_should_preserve_already_managed_input_file_id(): + """ + When input_file_id is already a managed/unified ID, it should + not be modified. + """ + import base64 + + unified_batch_id = "bGl0ZWxsbV9wcm94eTpiYXRjaF9pZA" + decoded_unified = "litellm_proxy:application/octet-stream;unified_id,test-123" + base64_input_file_id = base64.urlsafe_b64encode(decoded_unified.encode()).decode().rstrip("=") + + batch_data = { + "id": "batch-raw-123", + "completion_window": "24h", + "created_at": 1700000000, + "endpoint": "/v1/chat/completions", + "input_file_id": base64_input_file_id, + "object": "batch", + "status": "completed", + } + + prisma = _mock_prisma(batch_json=json.dumps(batch_data)) + + _, response = await get_batch_from_database( + batch_id=unified_batch_id, + unified_batch_id="decoded_unified_batch_id", + managed_files_obj=MagicMock(), + prisma_client=prisma, + verbose_proxy_logger=MagicMock(), + ) + + assert response is not None + assert response.input_file_id == base64_input_file_id, ( + f"input_file_id was already managed, should be preserved as '{base64_input_file_id}', " + f"got: '{response.input_file_id}'" + ) + + prisma.db.litellm_managedfiletable.find_first.assert_not_called() diff --git a/tests/test_litellm/enterprise/proxy/test_deleted_file_returns_403_not_404.py b/tests/test_litellm/enterprise/proxy/test_deleted_file_returns_403_not_404.py new file mode 100644 index 00000000000..7ad564dc8f9 --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_deleted_file_returns_403_not_404.py @@ -0,0 +1,119 @@ +""" +Regression test: deleted managed files should return 404, not 403. 
+ +When a managed file's DB record has been deleted, can_user_call_unified_file_id() +raises HTTPException(404) directly — rather than returning True (which would +weaken access control) or False (which would cause a misleading 403). +""" + +import base64 + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from fastapi import HTTPException + +from litellm.proxy._types import UserAPIKeyAuth + + +def _make_user_api_key_dict(user_id: str) -> UserAPIKeyAuth: + return UserAPIKeyAuth( + api_key="sk-test", + user_id=user_id, + parent_otel_span=None, + ) + + +def _make_unified_file_id() -> str: + raw = "litellm_proxy:application/octet-stream;unified_id,test-deleted-file;target_model_names,azure-gpt-4" + return base64.b64encode(raw.encode()).decode() + + +def _make_managed_files_with_no_db_record(): + """Create a _PROXY_LiteLLMManagedFiles where the DB returns None (file was deleted).""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock(return_value=None) + + return _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + + +@pytest.mark.asyncio +async def test_should_raise_404_for_deleted_file(): + """ + When a managed file record has been deleted from the DB, + check_managed_file_id_access should raise 404 (not 403). 
+ """ + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_with_no_db_record() + user = _make_user_api_key_dict("any-user") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, user) + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_should_allow_owner_access_when_record_exists(): + """Baseline: file owner can access their own file.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + unified_file_id = _make_unified_file_id() + + mock_db_record = MagicMock() + mock_db_record.created_by = "user-A" + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=mock_db_record + ) + + managed_files = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + + user = _make_user_api_key_dict("user-A") + data = {"file_id": unified_file_id} + + result = await managed_files.check_managed_file_id_access(data, user) + assert result is True + + +@pytest.mark.asyncio +async def test_should_block_different_user_when_record_exists(): + """Baseline: different user cannot access another user's file.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + unified_file_id = _make_unified_file_id() + + mock_db_record = MagicMock() + mock_db_record.created_by = "user-A" + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=mock_db_record + ) + + managed_files = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + + user = _make_user_api_key_dict("user-B") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, user) + assert exc_info.value.status_code == 403 diff 
--git a/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py b/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py new file mode 100644 index 00000000000..2db5a2214cb --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py @@ -0,0 +1,200 @@ +""" +Tests for managed files access control in batch polling context. + +Regression test for: batch polling job running as default_user_id gets 403 +when trying to access managed files created by a real user. + +The fix (Option C) makes check_batch_cost call litellm.afile_content directly +with deployment credentials, bypassing the managed files access-control hooks. +""" + +import base64 +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import HTTPException + +from litellm.proxy._types import UserAPIKeyAuth + + +def _make_user_api_key_dict(user_id: str) -> UserAPIKeyAuth: + return UserAPIKeyAuth( + api_key="sk-test", + user_id=user_id, + parent_otel_span=None, + ) + + +def _make_unified_file_id() -> str: + """Create a base64-encoded unified file ID that passes _is_base64_encoded_unified_file_id.""" + raw = "litellm_proxy:application/octet-stream;unified_id,test-123;target_model_names,azure-gpt-4" + return base64.b64encode(raw.encode()).decode() + + +def _make_managed_files_instance(file_created_by: str, unified_file_id: str): + """Create a _PROXY_LiteLLMManagedFiles with a mocked DB that returns a file owned by file_created_by.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + mock_db_record = MagicMock() + mock_db_record.created_by = file_created_by + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=mock_db_record + ) + + instance = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + return instance + + +# --- Access control unit tests (document existing behavior) 
--- + + +@pytest.mark.asyncio +async def test_should_allow_file_owner_access(): + """File owner can access their own file — baseline sanity check.""" + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_instance( + file_created_by="user-A", + unified_file_id=unified_file_id, + ) + user = _make_user_api_key_dict("user-A") + data = {"file_id": unified_file_id} + + result = await managed_files.check_managed_file_id_access(data, user) + assert result is True + + +@pytest.mark.asyncio +async def test_should_block_different_user_access(): + """A different regular user cannot access another user's file — correct behavior.""" + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_instance( + file_created_by="user-A", + unified_file_id=unified_file_id, + ) + user = _make_user_api_key_dict("user-B") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, user) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_should_block_default_user_id_access(): + """ + default_user_id is correctly blocked by the access check. + This documents the existing behavior that the Option C fix works around. 
+ """ + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_instance( + file_created_by="user-A", + unified_file_id=unified_file_id, + ) + system_user = _make_user_api_key_dict("default_user_id") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, system_user) + assert exc_info.value.status_code == 403 + + +# --- Option C fix test: check_batch_cost bypasses managed files hook --- + + +@pytest.mark.asyncio +async def test_check_batch_cost_should_call_afile_content_directly_with_credentials(): + """ + check_batch_cost should call litellm.afile_content directly with deployment + credentials, bypassing managed_files_obj.afile_content and its access-control + hooks. This avoids the 403 that occurs when the background job runs as + default_user_id. + """ + from litellm_enterprise.proxy.common_utils.check_batch_cost import CheckBatchCost + + # Build a unified object ID in the expected format: + # litellm_proxy;model_id:{};llm_batch_id:{};llm_output_file_id:{} + unified_raw = "litellm_proxy;model_id:model-deploy-xyz;llm_batch_id:batch-123;llm_output_file_id:file-raw-output" + unified_object_id = base64.b64encode(unified_raw.encode()).decode() + + # Mock a pending job from the DB + mock_job = MagicMock() + mock_job.unified_object_id = unified_object_id + mock_job.created_by = "user-A" + mock_job.id = "job-1" + + # Mock prisma + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedobjecttable.find_many = AsyncMock( + return_value=[mock_job] + ) + mock_prisma.db.litellm_managedobjecttable.update_many = AsyncMock() + + # Mock proxy_logging_obj — should NOT be called for file content + mock_proxy_logging = MagicMock() + mock_managed_files_hook = MagicMock() + mock_managed_files_hook.afile_content = AsyncMock() + mock_proxy_logging.get_proxy_hook = MagicMock(return_value=mock_managed_files_hook) + + # Mock the batch response (completed, with output file) 
+ from litellm.types.utils import LiteLLMBatch + batch_response = LiteLLMBatch( + id="batch-123", + completion_window="24h", + created_at=1700000000, + endpoint="/v1/chat/completions", + input_file_id="file-input", + object="batch", + status="completed", + output_file_id="file-raw-output", + ) + + # Mock router + mock_router = MagicMock() + mock_router.aretrieve_batch = AsyncMock(return_value=batch_response) + mock_router.get_deployment_credentials_with_provider = MagicMock( + return_value={ + "api_key": "test-key", + "api_base": "https://test.azure.com/", + "custom_llm_provider": "azure", + } + ) + + mock_deployment = MagicMock() + mock_deployment.litellm_params.custom_llm_provider = "azure" + mock_deployment.litellm_params.model = "azure/gpt-4" + mock_router.get_deployment = MagicMock(return_value=mock_deployment) + + checker = CheckBatchCost( + proxy_logging_obj=mock_proxy_logging, + prisma_client=mock_prisma, + llm_router=mock_router, + ) + + mock_file_content = MagicMock() + mock_file_content.content = b'{"id":"req-1","response":{"status_code":200,"body":{"id":"cmpl-1","object":"chat.completion","created":1700000000,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}}}\n' + + with patch( + "litellm.files.main.afile_content", + new_callable=AsyncMock, + return_value=mock_file_content, + ) as mock_direct_afile_content: + await checker.check_batch_cost() + + # afile_content should be called directly (not through managed_files_obj) + mock_direct_afile_content.assert_called_once() + call_kwargs = mock_direct_afile_content.call_args.kwargs + + assert call_kwargs.get("api_key") == "test-key", ( + f"afile_content should receive api_key from deployment credentials. 
" + f"Got: {call_kwargs}" + ) + + # managed_files_obj.afile_content should NOT have been called + mock_managed_files_hook.afile_content.assert_not_called() diff --git a/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py b/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py new file mode 100644 index 00000000000..9526304aff0 --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py @@ -0,0 +1,167 @@ +""" +Tests for enterprise/litellm_enterprise/proxy/hooks/managed_files.py + +Regression test for afile_retrieve called without credentials in +async_post_call_success_hook when processing completed batch responses. +""" + +import pytest +from typing import Optional +from unittest.mock import AsyncMock, MagicMock, patch + +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.llms.openai import OpenAIFileObject +from litellm.types.utils import LiteLLMBatch + + +def _make_file_object(file_id: str = "file-output-abc") -> OpenAIFileObject: + return OpenAIFileObject( + id=file_id, + bytes=100, + created_at=1700000000, + filename="output.jsonl", + object="file", + purpose="batch_output", + status="processed", + ) + + +def _make_batch_response( + batch_id: str = "batch-123", + output_file_id: Optional[str] = "file-output-abc", + status: str = "completed", + model_id: str = "model-deploy-xyz", + model_name: str = "azure/gpt-4", +) -> LiteLLMBatch: + """Create a LiteLLMBatch response with hidden params set as the router would.""" + batch = LiteLLMBatch( + id=batch_id, + completion_window="24h", + created_at=1700000000, + endpoint="/v1/chat/completions", + input_file_id="file-input-abc", + object="batch", + status=status, + output_file_id=output_file_id, + ) + batch._hidden_params = { + "unified_file_id": "some-unified-id", + "unified_batch_id": "some-unified-batch-id", + "model_id": model_id, + "model_name": model_name, + } + return batch + + +def _make_user_api_key_dict() -> UserAPIKeyAuth: + return UserAPIKeyAuth( + 
api_key="sk-test", + user_id="test-user", + parent_otel_span=None, + ) + + +def _make_managed_files_instance(): + """Create a _PROXY_LiteLLMManagedFiles with storage methods mocked out.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + mock_cache = MagicMock() + mock_prisma = MagicMock() + + instance = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=mock_cache, + prisma_client=mock_prisma, + ) + instance.store_unified_file_id = AsyncMock() + instance.store_unified_object_id = AsyncMock() + return instance + + +@pytest.mark.asyncio +async def test_should_pass_credentials_to_afile_retrieve(): + """ + When async_post_call_success_hook processes a completed batch with an output_file_id, + it calls afile_retrieve to fetch file metadata. It must pass credentials from the + router deployment, not just custom_llm_provider and file_id. + + Regression test for: managed_files.py:919 calling afile_retrieve without api_key/api_base. + """ + managed_files = _make_managed_files_instance() + batch_response = _make_batch_response( + model_id="model-deploy-xyz", + model_name="azure/gpt-4", + output_file_id="file-output-abc", + ) + user_api_key_dict = _make_user_api_key_dict() + + mock_credentials = { + "api_key": "test-azure-key", + "api_base": "https://my-azure.openai.azure.com/", + "api_version": "2025-03-01-preview", + "custom_llm_provider": "azure", + } + + mock_router = MagicMock() + mock_router.get_deployment_credentials_with_provider = MagicMock( + return_value=mock_credentials + ) + + mock_afile_retrieve = AsyncMock(return_value=_make_file_object("file-output-abc")) + + with patch( + "litellm.afile_retrieve", mock_afile_retrieve + ), patch( + "litellm.proxy.proxy_server.llm_router", mock_router + ): + await managed_files.async_post_call_success_hook( + data={}, + user_api_key_dict=user_api_key_dict, + response=batch_response, + ) + + mock_afile_retrieve.assert_called() + call_kwargs = mock_afile_retrieve.call_args + + 
assert call_kwargs.kwargs.get("api_key") == "test-azure-key", ( + f"afile_retrieve must receive api_key from router credentials. " + f"Got kwargs: {call_kwargs.kwargs}" + ) + assert call_kwargs.kwargs.get("api_base") == "https://my-azure.openai.azure.com/", ( + f"afile_retrieve must receive api_base from router credentials. " + f"Got kwargs: {call_kwargs.kwargs}" + ) + + +@pytest.mark.asyncio +async def test_should_fallback_when_no_router(): + """ + When llm_router is not available, afile_retrieve should still be called + with the fallback behavior (custom_llm_provider extracted from model_name). + """ + managed_files = _make_managed_files_instance() + batch_response = _make_batch_response( + model_id="model-deploy-xyz", + model_name="azure/gpt-4", + output_file_id="file-output-abc", + ) + user_api_key_dict = _make_user_api_key_dict() + + mock_afile_retrieve = AsyncMock(return_value=_make_file_object("file-output-abc")) + + with patch( + "litellm.afile_retrieve", mock_afile_retrieve + ), patch( + "litellm.proxy.proxy_server.llm_router", None + ): + await managed_files.async_post_call_success_hook( + data={}, + user_api_key_dict=user_api_key_dict, + response=batch_response, + ) + + mock_afile_retrieve.assert_called() + call_kwargs = mock_afile_retrieve.call_args + assert call_kwargs.kwargs.get("custom_llm_provider") == "azure" + assert call_kwargs.kwargs.get("file_id") == "file-output-abc" diff --git a/tests/test_litellm/integrations/test_s3_v2.py b/tests/test_litellm/integrations/test_s3_v2.py index 0a3523699a9..51fec288f6e 100644 --- a/tests/test_litellm/integrations/test_s3_v2.py +++ b/tests/test_litellm/integrations/test_s3_v2.py @@ -157,6 +157,51 @@ def test_s3_v2_endpoint_url(self, mock_periodic_flush, mock_create_task): assert result == {"downloaded": "data"} +@pytest.mark.asyncio +async def test_async_log_event_skips_when_standard_logging_object_missing(): + """ + Reproduces the bug where _async_log_event_base raises ValueError when + kwargs has no 
standard_logging_object (e.g. call_type=afile_delete). + + The S3 logger should skip gracefully, not raise. + """ + logger = S3Logger( + s3_bucket_name="test-bucket", + s3_region_name="us-east-1", + s3_aws_access_key_id="fake", + s3_aws_secret_access_key="fake", + ) + + kwargs_without_slo = { + "call_type": "afile_delete", + "model": None, + "litellm_call_id": "test-call-id", + } + + start_time = datetime.utcnow() + end_time = datetime.utcnow() + + # Spy on handle_callback_failure — should NOT be called if we skip gracefully. + # Without the fix, the ValueError is caught by the except block which calls + # handle_callback_failure. With the fix, we return early and never hit except. + with patch.object(logger, "handle_callback_failure") as mock_failure: + await logger._async_log_event_base( + kwargs=kwargs_without_slo, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + assert not mock_failure.called, ( + "handle_callback_failure should not be called — " + "missing standard_logging_object should be a graceful skip, not an error" + ) + + # Nothing should have been queued (catches the case where code falls + # through without returning and appends None to the queue) + assert len(logger.log_queue) == 0, "log_queue should be empty when standard_logging_object is missing" + + @pytest.mark.asyncio async def test_strip_base64_removes_file_and_nontext_entries(): logger = S3Logger(s3_strip_base64_files=True) diff --git a/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py b/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py index cb6d90103f7..e8765cf78ca 100644 --- a/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py +++ b/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py @@ -126,3 +126,77 @@ async def test_async_post_call_failure_hook_non_llm_route(): # Assert that update_database was NOT called for non-LLM routes mock_update_database.assert_not_called() + + +@pytest.mark.asyncio +async def 
test_track_cost_callback_skips_when_no_standard_logging_object(): + """ + Reproduces the bug where _PROXY_track_cost_callback raises + 'Cost tracking failed for model=None' when kwargs has no + standard_logging_object (e.g. call_type=afile_delete). + + File operations have no model and no standard_logging_object. + The callback should skip gracefully instead of raising. + """ + logger = _ProxyDBLogger() + + kwargs = { + "call_type": "afile_delete", + "model": None, + "litellm_call_id": "test-call-id", + "litellm_params": {}, + "stream": False, + } + + with patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + ) as mock_proxy_logging: + mock_proxy_logging.failed_tracking_alert = AsyncMock() + mock_proxy_logging.db_spend_update_writer = MagicMock() + mock_proxy_logging.db_spend_update_writer.update_database = AsyncMock() + + await logger._PROXY_track_cost_callback( + kwargs=kwargs, + completion_response=None, + start_time=datetime.now(), + end_time=datetime.now(), + ) + + # update_database should NOT be called — nothing to track + mock_proxy_logging.db_spend_update_writer.update_database.assert_not_called() + + # failed_tracking_alert should NOT be called — this is not an error + mock_proxy_logging.failed_tracking_alert.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_value", [None, ""]) +async def test_track_cost_callback_skips_for_falsy_model_and_no_slo(model_value): + """ + Same bug as above but model can also be empty string (e.g. health check callbacks). + The guard should catch all falsy model values when sl_object is missing. 
+ """ + logger = _ProxyDBLogger() + + kwargs = { + "call_type": "acompletion", + "model": model_value, + "litellm_params": {}, + "stream": False, + } + + with patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + ) as mock_proxy_logging: + mock_proxy_logging.failed_tracking_alert = AsyncMock() + mock_proxy_logging.db_spend_update_writer = MagicMock() + mock_proxy_logging.db_spend_update_writer.update_database = AsyncMock() + + await logger._PROXY_track_cost_callback( + kwargs=kwargs, + completion_response=None, + start_time=datetime.now(), + end_time=datetime.now(), + ) + + mock_proxy_logging.failed_tracking_alert.assert_not_called() diff --git a/tests/test_litellm/test_service_logger.py b/tests/test_litellm/test_service_logger.py new file mode 100644 index 00000000000..ed44fe9b9f2 --- /dev/null +++ b/tests/test_litellm/test_service_logger.py @@ -0,0 +1,97 @@ +""" +Tests for litellm/_service_logger.py + +Regression test for KeyError: 'call_type' when async_log_success_event +is called without call_type in kwargs (e.g. from batch polling callbacks). +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, patch + +from litellm._service_logger import ServiceLogging + + +@pytest.mark.asyncio +async def test_async_log_success_event_should_not_raise_when_call_type_missing(): + """ + When async_log_success_event is called with kwargs that omit 'call_type', + it should not raise a KeyError. This happens in the batch polling flow + where check_batch_cost.py creates a Logging object whose model_call_details + don't include call_type. 
+ """ + service_logger = ServiceLogging(mock_testing=True) + + start_time = datetime(2026, 2, 13, 22, 35, 0) + end_time = datetime(2026, 2, 13, 22, 35, 1) + kwargs_without_call_type = {"model": "gpt-4", "stream": False} + + with patch.object( + service_logger, "async_service_success_hook", new_callable=AsyncMock + ) as mock_hook: + await service_logger.async_log_success_event( + kwargs=kwargs_without_call_type, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + mock_hook.assert_called_once() + call_kwargs = mock_hook.call_args + assert call_kwargs.kwargs["call_type"] == "unknown" + + +@pytest.mark.asyncio +async def test_async_log_success_event_should_pass_call_type_when_present(): + """ + When call_type IS present in kwargs, it should be forwarded correctly. + """ + service_logger = ServiceLogging(mock_testing=True) + + start_time = datetime(2026, 2, 13, 22, 35, 0) + end_time = datetime(2026, 2, 13, 22, 35, 1) + kwargs_with_call_type = { + "model": "gpt-4", + "stream": False, + "call_type": "aretrieve_batch", + } + + with patch.object( + service_logger, "async_service_success_hook", new_callable=AsyncMock + ) as mock_hook: + await service_logger.async_log_success_event( + kwargs=kwargs_with_call_type, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + mock_hook.assert_called_once() + call_kwargs = mock_hook.call_args + assert call_kwargs.kwargs["call_type"] == "aretrieve_batch" + + +@pytest.mark.asyncio +async def test_async_log_success_event_should_handle_float_duration(): + """ + When start_time and end_time produce a float duration (not timedelta), + it should still work correctly. 
+ """ + service_logger = ServiceLogging(mock_testing=True) + + start_time = 1000.0 + end_time = 1001.5 + + with patch.object( + service_logger, "async_service_success_hook", new_callable=AsyncMock + ) as mock_hook: + await service_logger.async_log_success_event( + kwargs={"call_type": "completion"}, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + mock_hook.assert_called_once() + call_kwargs = mock_hook.call_args + assert call_kwargs.kwargs["duration"] == 1.5