diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index dd0613f87ff..5ee3372cca7 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -369,6 +369,7 @@ async def async_pre_call_hook( # noqa: PLR0915 if ( call_type == CallTypes.afile_content.value or call_type == CallTypes.afile_delete.value + or call_type == CallTypes.afile_retrieve.value ): await self.check_managed_file_id_access(data, user_api_key_dict) @@ -433,12 +434,16 @@ async def async_pre_call_hook( # noqa: PLR0915 data["model_file_id_mapping"] = model_file_id_mapping elif ( call_type == CallTypes.aretrieve_batch.value + or call_type == CallTypes.acancel_batch.value or call_type == CallTypes.acancel_fine_tuning_job.value or call_type == CallTypes.aretrieve_fine_tuning_job.value ): accessor_key: Optional[str] = None retrieve_object_id: Optional[str] = None - if call_type == CallTypes.aretrieve_batch.value: + if ( + call_type == CallTypes.aretrieve_batch.value + or call_type == CallTypes.acancel_batch.value + ): accessor_key = "batch_id" elif ( call_type == CallTypes.acancel_fine_tuning_job.value @@ -966,8 +971,10 @@ async def afile_delete( delete_response = None specific_model_file_id_mapping = model_file_id_mapping.get(file_id) if specific_model_file_id_mapping: + # Remove conflicting keys from data to avoid duplicate keyword arguments + filtered_data = {k: v for k, v in data.items() if k not in ("model", "file_id")} for model_id, 
model_file_id in specific_model_file_id_mapping.items(): - delete_response = await llm_router.afile_delete(model=model_id, file_id=model_file_id, **data) # type: ignore + delete_response = await llm_router.afile_delete(model=model_id, file_id=model_file_id, **filtered_data) # type: ignore stored_file_object = await self.delete_unified_file_id( file_id, litellm_parent_otel_span diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 3367f567a7f..f7fcaed4979 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -31,7 +31,6 @@ from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import ( - Batch, CancelBatchRequest, CreateBatchRequest, RetrieveBatchRequest, @@ -868,7 +867,7 @@ async def acancel_batch( extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, -) -> Batch: +) -> LiteLLMBatch: """ Async: Cancels a batch. @@ -912,7 +911,7 @@ def cancel_batch( extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, -) -> Union[Batch, Coroutine[Any, Any, Batch]]: +) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: """ Cancels a batch. 
diff --git a/litellm/llms/azure/batches/handler.py b/litellm/llms/azure/batches/handler.py index 3996cb808e4..aaefe801687 100644 --- a/litellm/llms/azure/batches/handler.py +++ b/litellm/llms/azure/batches/handler.py @@ -5,12 +5,10 @@ from typing import Any, Coroutine, Optional, Union, cast import httpx - from openai import AsyncOpenAI, OpenAI from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI from litellm.types.llms.openai import ( - Batch, CancelBatchRequest, CreateBatchRequest, RetrieveBatchRequest, @@ -130,9 +128,9 @@ async def acancel_batch( self, cancel_batch_data: CancelBatchRequest, client: Union[AsyncAzureOpenAI, AsyncOpenAI], - ) -> Batch: + ) -> LiteLLMBatch: response = await client.batches.cancel(**cancel_batch_data) - return response + return LiteLLMBatch(**response.model_dump()) def cancel_batch( self, @@ -160,8 +158,23 @@ def cancel_batch( raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) + + if _is_async is True: + if not isinstance(azure_client, (AsyncAzureOpenAI, AsyncOpenAI)): + raise ValueError( + "Azure client is not an instance of AsyncAzureOpenAI or AsyncOpenAI. Make sure you passed an async client." + ) + return self.acancel_batch( # type: ignore + cancel_batch_data=cancel_batch_data, client=azure_client + ) + + # At this point, azure_client is guaranteed to be a sync client + if not isinstance(azure_client, (AzureOpenAI, OpenAI)): + raise ValueError( + "Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client." 
+ ) response = azure_client.batches.cancel(**cancel_batch_data) - return response + return LiteLLMBatch(**response.model_dump()) async def alist_batches( self, diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 4d623097478..8a8070240da 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -1923,10 +1923,10 @@ async def acancel_batch( self, cancel_batch_data: CancelBatchRequest, openai_client: AsyncOpenAI, - ) -> Batch: + ) -> LiteLLMBatch: verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data) response = await openai_client.batches.cancel(**cancel_batch_data) - return response + return LiteLLMBatch(**response.model_dump()) def cancel_batch( self, @@ -1962,8 +1962,13 @@ def cancel_batch( cancel_batch_data=cancel_batch_data, openai_client=openai_client ) + # At this point, openai_client is guaranteed to be a sync OpenAI client + if not isinstance(openai_client, OpenAI): + raise ValueError( + "OpenAI client is not an instance of OpenAI. Make sure you passed a sync OpenAI client." 
+ ) response = openai_client.batches.cancel(**cancel_batch_data) - return response + return LiteLLMBatch(**response.model_dump()) async def alist_batches( self, diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d9f824be222..9ef451bd70e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -293,6 +293,8 @@ class LiteLLMRoutes(enum.Enum): "/batches", "/v1/batches/{batch_id}", "/batches/{batch_id}", + "/v1/batches/{batch_id}/cancel", + "/batches/{batch_id}/cancel", # files "/v1/files", "/files", diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 7e7c7c8c90c..748135ee2d2 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -1239,7 +1239,6 @@ async def user_api_key_auth( request_data=request_data, request=request ) route: str = get_request_route(request=request) - ## CHECK IF ROUTE IS ALLOWED user_api_key_auth_obj = await _user_api_key_auth_builder( @@ -1263,7 +1262,6 @@ async def user_api_key_auth( user_api_key_auth_obj.end_user_id = end_user_id user_api_key_auth_obj.request_route = normalize_request_route(route) - return user_api_key_auth_obj diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 078e21f9bb4..220ecc5453c 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -24,9 +24,7 @@ _is_base64_encoded_unified_file_id, decode_model_from_file_id, encode_file_id_with_model, - get_batch_id_from_unified_batch_id, get_credentials_for_model, - get_model_id_from_unified_batch_id, get_models_from_unified_file_id, get_original_file_id, prepare_data_with_credentials, @@ -382,25 +380,6 @@ async def retrieve_batch( **data # type: ignore ) - # Re-encode all IDs in the response - if response: - if hasattr(response, "id") and response.id: - response.id = batch_id # Keep the encoded batch ID - - if hasattr(response, "input_file_id") 
and response.input_file_id: - response.input_file_id = encode_file_id_with_model( - file_id=response.input_file_id, model=model_from_id - ) - - if hasattr(response, "output_file_id") and response.output_file_id: - response.output_file_id = encode_file_id_with_model( - file_id=response.output_file_id, model=model_from_id - ) - - if hasattr(response, "error_file_id") and response.error_file_id: - response.error_file_id = encode_file_id_with_model( - file_id=response.error_file_id, model=model_from_id - ) verbose_proxy_logger.debug( f"Retrieved batch using model: {model_from_id}, original_id: {original_batch_id}" @@ -695,15 +674,31 @@ async def cancel_batch( data: Dict = {} try: - data = await _read_request_body(request=request) - verbose_proxy_logger.debug( - "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), - ) - # Check for encoded batch ID with model info model_from_id = decode_model_from_file_id(batch_id) + + # Create CancelBatchRequest with batch_id to enable ownership checking + _cancel_batch_request = CancelBatchRequest( + batch_id=batch_id, + ) + data = cast(dict, _cancel_batch_request) + unified_batch_id = _is_base64_encoded_unified_file_id(batch_id) + base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data) + ( + data, + litellm_logging_obj, + ) = await base_llm_response_processor.common_processing_pre_call_logic( + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_logging_obj=proxy_logging_obj, + proxy_config=proxy_config, + route_type="acancel_batch", + ) + # Include original request and headers in the data data = await add_litellm_data_to_request( data=data, @@ -751,17 +746,13 @@ async def cancel_batch( }, ) - model = ( - get_model_id_from_unified_batch_id(unified_batch_id) - if unified_batch_id - else None - ) - - model_batch_id = get_batch_id_from_unified_batch_id(unified_batch_id) - - data["batch_id"] = model_batch_id - - response = await 
llm_router.acancel_batch(model=model, **data) # type: ignore + # Hook has already extracted model and unwrapped batch_id into data dict + response = await llm_router.acancel_batch(**data) # type: ignore + response._hidden_params["unified_batch_id"] = unified_batch_id + + # Ensure model_id is set for the post_call_success_hook to re-encode IDs + if not response._hidden_params.get("model_id") and data.get("model"): + response._hidden_params["model_id"] = data["model"] # SCENARIO 3: Fallback to custom_llm_provider (uses env variables) else: @@ -775,6 +766,11 @@ async def cancel_batch( **_cancel_batch_data, ) + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + data=data, user_api_key_dict=user_api_key_dict, response=response + ) + ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 0d3e61b75c7..79392cb5fac 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -392,8 +392,10 @@ async def common_processing_pre_call_logic( "acreate_batch", "aretrieve_batch", "alist_batches", + "acancel_batch", "afile_content", "afile_retrieve", + "afile_delete", "atext_completion", "acreate_fine_tuning_job", "acancel_fine_tuning_job", @@ -606,6 +608,8 @@ async def base_process_llm_request( "aget_interaction", "adelete_interaction", "acancel_interaction", + "acancel_batch", + "afile_delete", ], proxy_logging_obj: ProxyLogging, general_settings: dict, diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py index 2c6b378ae38..1c9909f2cc5 100644 --- a/litellm/proxy/openai_files_endpoints/files_endpoints.py +++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py @@ -812,7 +812,7 @@ async def get_file( version, ) - data: Dict = {} + data: Dict = {"file_id": file_id} try: 
custom_llm_provider = ( @@ -992,7 +992,7 @@ async def delete_file( version, ) - data: Dict = {} + data: Dict = {"file_id": file_id} try: custom_llm_provider = ( provider @@ -1001,6 +1001,22 @@ async def delete_file( or await get_custom_llm_provider_from_request_body(request=request) or "openai" ) + + # Call common_processing_pre_call_logic to trigger permission checks + base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data) + ( + data, + litellm_logging_obj, + ) = await base_llm_response_processor.common_processing_pre_call_logic( + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_logging_obj=proxy_logging_obj, + proxy_config=proxy_config, + route_type="afile_delete", + ) + # Include original request and headers in the data data = await add_litellm_data_to_request( data=data, @@ -1060,11 +1076,13 @@ async def delete_file( code=500, ) + # Remove file_id from data to avoid duplicate keyword argument + data_without_file_id = {k: v for k, v in data.items() if k != "file_id"} response = await managed_files_obj.afile_delete( file_id=file_id, litellm_parent_otel_span=user_api_key_dict.parent_otel_span, llm_router=llm_router, - **data, + **data_without_file_id, ) else: response = await litellm.afile_delete( diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py index 8d18322d374..9239287dcb7 100644 --- a/litellm/types/llms/anthropic.py +++ b/litellm/types/llms/anthropic.py @@ -296,9 +296,9 @@ class AnthropicMessagesDocumentParam(TypedDict, total=False): citations: Optional[CitationsObject] -class AnthropicMessagesToolResultContent(TypedDict): - type: Literal["text"] - text: str +class AnthropicMessagesToolResultContent(TypedDict, total=False): + type: Required[Literal["text"]] + text: Required[str] cache_control: Optional[Union[dict, ChatCompletionCachedContent]] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 49c903502d1..f063418f92e 100644 --- 
a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -276,6 +276,8 @@ class CallTypes(str, Enum): acreate_batch = "acreate_batch" aretrieve_batch = "aretrieve_batch" retrieve_batch = "retrieve_batch" + acancel_batch = "acancel_batch" + cancel_batch = "cancel_batch" pass_through = "pass_through_endpoint" anthropic_messages = "anthropic_messages" get_assistants = "get_assistants" diff --git a/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py b/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py index 4fa16066e4b..3fd19cfa18f 100644 --- a/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py +++ b/tests/enterprise/litellm_enterprise/proxy/hooks/test_managed_files.py @@ -235,10 +235,10 @@ async def test_async_pre_call_hook_for_unified_finetuning_job(): @pytest.mark.asyncio -@pytest.mark.parametrize("call_type", ["afile_content", "afile_delete"]) +@pytest.mark.parametrize("call_type", ["afile_content", "afile_delete", "afile_retrieve"]) async def test_can_user_call_unified_file_id(call_type): """ - Test that on file retrieve, delete we check if the user has access to the file + Test that on file retrieve, delete, and content we check if the user has access to the file """ from litellm.proxy._types import UserAPIKeyAuth @@ -376,10 +376,12 @@ async def test_output_file_id_for_batch_retrieve(): @pytest.mark.asyncio async def test_async_post_call_success_hook_twice_assert_no_unique_violation(): import asyncio - from litellm.types.utils import LiteLLMBatch - from litellm.proxy._types import UserAPIKeyAuth + from openai.types.batch import BatchRequestCounts + from litellm.proxy._types import UserAPIKeyAuth + from litellm.types.utils import LiteLLMBatch + # Use AsyncMock instead of real database connection prisma_client = AsyncMock() @@ -456,7 +458,7 @@ def test_update_responses_input_with_unified_file_id(): from litellm.litellm_core_utils.prompt_templates.common_utils import ( 
update_responses_input_with_model_file_ids, ) - + # Create a base64-encoded unified file ID # This decodes to: litellm_proxy:application/pdf;unified_id,6c0b5890-8914-48e0-b8f4-0ae5ed3c14a5;target_model_names,gpt-4o;llm_output_file_id,file-ECBPW7ML9g7XHdwGgUPZaM;llm_output_file_model_id,e26453f9e76e7993680d0068d98c1f4cc205bbad0967a33c664893568ca743c2 unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" @@ -496,7 +498,7 @@ def test_update_responses_input_with_regular_file_id(): from litellm.litellm_core_utils.prompt_templates.common_utils import ( update_responses_input_with_model_file_ids, ) - + # Regular OpenAI file ID (not a unified file ID) regular_file_id = "file-abc123xyz" @@ -549,7 +551,7 @@ def test_update_responses_input_with_multiple_file_ids(): from litellm.litellm_core_utils.prompt_templates.common_utils import ( update_responses_input_with_model_file_ids, ) - + # Unified file ID unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFkMDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My" # Regular OpenAI file ID @@ -831,9 +833,10 @@ async def test_afile_retrieve_raises_error_for_non_managed_file(): @pytest.mark.asyncio async def test_list_batches_from_managed_objects_table(): - from litellm.proxy._types import UserAPIKeyAuth from openai.types.batch import BatchRequestCounts + from litellm.proxy._types import UserAPIKeyAuth + prisma_client = AsyncMock() batch_record_1 = MagicMock() @@ -1085,4 +1088,378 @@ async def 
test_return_unified_file_id_includes_expires_at(): assert result.filename == "test.jsonl" assert result.bytes == 1234 assert result.created_at == 1234567890 - assert _is_base64_encoded_unified_file_id(result.id) \ No newline at end of file + assert _is_base64_encoded_unified_file_id(result.id) + + +# ============================================================================ +# Permission Tests - Cross-User Batch Access +# ============================================================================ +# These tests verify that batches and files created by one user +# cannot be accessed, modified, or cancelled by a different user. +# Reference: https://github.com/BerriAI/litellm/pull/17401/files + + +@pytest.mark.asyncio +async def test_user_b_cannot_retrieve_user_a_batch(): + """ + Test that User B cannot retrieve a batch created by User A. + + This verifies batch isolation between users at the database/hook level. + """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + batch_record = MagicMock() + batch_record.created_by = "user_a_id" + prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + DualCache(), prisma_client=prisma_client + ) + + # User B tries to retrieve User A's batch + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + + with pytest.raises(HTTPException) as exc_info: + await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_b_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"batch_id": unified_batch_id}, + call_type="aretrieve_batch", + ) + + # Should raise 403 Permission Denied + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_user_b_cannot_cancel_user_a_batch(): + """ + Test that User B cannot cancel a batch created by User A. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + batch_record = MagicMock() + batch_record.created_by = "user_a_id" + prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + DualCache(), prisma_client=prisma_client + ) + + # User B tries to cancel User A's batch + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + + with pytest.raises(HTTPException) as exc_info: + await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_b_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"batch_id": unified_batch_id}, + call_type="acancel_batch", + ) + + # Should raise 403 Permission Denied + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_user_a_can_retrieve_own_batch(): + """ + Test that User A can successfully retrieve their own batch. + + This is a positive test case to ensure permission checks don't block + legitimate access. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + batch_record = MagicMock() + batch_record.created_by = "user_a_id" + prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + DualCache(), prisma_client=prisma_client + ) + + # User A retrieves their own batch + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + + # Should not raise an exception + result = await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_a_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"batch_id": unified_batch_id}, + call_type="aretrieve_batch", + ) + + # Should successfully return the decoded batch_id + assert "batch_id" in result + assert result["model"] == "my-model" + + +@pytest.mark.asyncio +async def test_user_b_cannot_retrieve_user_a_file(): + """ + Test that User B cannot retrieve a file created by User A. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + file_record = MagicMock() + file_record.created_by = "user_a_id" + prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + MagicMock(), prisma_client=prisma_client + ) + + # User B tries to retrieve User A's file + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + + with pytest.raises(HTTPException) as exc_info: + await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_b_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"file_id": unified_file_id}, + call_type="afile_retrieve", + ) + + # Should raise 403 Permission Denied + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_user_b_cannot_download_user_a_file_content(): + """ + Test that User B cannot download file content for User A's file. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + file_record = MagicMock() + file_record.created_by = "user_a_id" + prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + MagicMock(), prisma_client=prisma_client + ) + + # User B tries to download User A's file content + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + + with pytest.raises(HTTPException) as exc_info: + await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_b_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"file_id": unified_file_id}, + call_type="afile_content", + ) + + # Should raise 403 Permission Denied + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_user_b_cannot_delete_user_a_file(): + """ + Test that User B cannot delete a file created by User A. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + file_record = MagicMock() + file_record.created_by = "user_a_id" + prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + MagicMock(), prisma_client=prisma_client + ) + + # User B tries to delete User A's file + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + + with pytest.raises(HTTPException) as exc_info: + await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_b_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"file_id": unified_file_id}, + call_type="afile_delete", + ) + + # Should raise 403 Permission Denied + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_user_a_can_retrieve_own_file(): + """ + Test that User A can successfully retrieve their own file. + + Positive test case to ensure permission checks work correctly for the owner. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return User A as the creator + file_record = MagicMock() + file_record.created_by = "user_a_id" + file_record.model_mappings = '{"model-123": "file-abc123"}' + file_record.file_object = json.dumps({ + "id": "file-abc123", + "object": "file", + "bytes": 1234, + "created_at": 1234567890, + "filename": "test.jsonl", + "purpose": "batch", + }) + prisma_client.db.litellm_managedfiletable.find_first.return_value = file_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + MagicMock(), prisma_client=prisma_client + ) + + # User A retrieves their own file + unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9qc29uO3VuaWZpZWRfaWQsZmlsZS1hYmMxMjM" + + # Should not raise an exception + result = await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_a_id", parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"file_id": unified_file_id}, + call_type="afile_retrieve", + ) + + # Should successfully return the decoded file_id + assert "file_id" in result + + +@pytest.mark.asyncio +async def test_list_batches_only_returns_user_own_batches(): + """ + Test that list_user_batches only returns batches created by the requesting user. + + This ensures users cannot see other users' batches in list operations. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Create batches for User A + batch_user_a = MagicMock() + batch_user_a.unified_object_id = "batch-user-a" + batch_user_a.file_object = json.dumps({ + "id": "batch_a", + "object": "batch", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "status": "completed", + "created_at": 1234567890, + "input_file_id": "file-a", + "request_counts": {"total": 1, "completed": 1, "failed": 0}, + }) + + # Mock database to only return User A's batches + prisma_client.db.litellm_managedobjecttable.find_many.return_value = [batch_user_a] + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + DualCache(), prisma_client=prisma_client + ) + + # User A requests their batches + result = await proxy_managed_files.list_user_batches( + user_api_key_dict=UserAPIKeyAuth(user_id="user_a_id"), + limit=10, + ) + + # Should only return User A's batches + assert len(result["data"]) == 1 + assert result["data"][0].id == "batch-user-a" + + # Verify the database query filtered by user_id + prisma_client.db.litellm_managedobjecttable.find_many.assert_called_once_with( + where={"file_purpose": "batch", "created_by": "user_a_id"}, + take=10, + order={"created_at": "desc"}, + ) + + +@pytest.mark.asyncio +async def test_same_user_different_keys_can_access_batch(): + """ + Test that different API keys for the same user can access the same batch. + + This verifies that permission checks are based on user_id, not API key, + allowing users to have multiple keys that can all access their resources. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + prisma_client = AsyncMock() + + # Mock database to return the user_id as creator + batch_record = MagicMock() + batch_record.created_by = "user_a_id" + prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record + + proxy_managed_files = _PROXY_LiteLLMManagedFiles( + DualCache(), prisma_client=prisma_client + ) + + unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1tb2RlbDtsbG1fYmF0Y2hfaWQ6YmF0Y2hfYWJjMTIz" + + # First API key for User A retrieves the batch + result1 = await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_a_id", + api_key="key-1", + parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"batch_id": unified_batch_id}, + call_type="aretrieve_batch", + ) + + assert "batch_id" in result1 + + # Second API key for the same User A retrieves the batch + result2 = await proxy_managed_files.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth( + user_id="user_a_id", + api_key="key-2", + parent_otel_span=MagicMock() + ), + cache=MagicMock(), + data={"batch_id": unified_batch_id}, + call_type="aretrieve_batch", + ) + + assert "batch_id" in result2 + # Both keys should get the same result + assert result1["batch_id"] == result2["batch_id"]