diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index d7aa6e9f0d0..ca60b9e1bec 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -760,6 +760,11 @@ model LiteLLM_ManagedVectorStoresTable { updated_at DateTime @updatedAt litellm_credential_name String? litellm_params Json? + team_id String? + user_id String? + + @@index([team_id]) + @@index([user_id]) } // Guardrails table for storing guardrail configurations diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e15179a8e29..d9f824be222 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -353,6 +353,9 @@ class LiteLLMRoutes(enum.Enum): "/v1/vector_stores/{vector_store_id}/files/{file_id}", "/vector_stores/{vector_store_id}/files/{file_id}/content", "/v1/vector_stores/{vector_store_id}/files/{file_id}/content", + "/vector_store/list", + "/v1/vector_store/list", + # search "/search", "/v1/search", @@ -3917,6 +3920,8 @@ class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase): updated_at: Optional[datetime] litellm_credential_name: Optional[str] litellm_params: Optional[Dict[str, Any]] + team_id: Optional[str] + user_id: Optional[str] class ResponseLiteLLM_ManagedVectorStore(TypedDict, total=False): diff --git a/litellm/proxy/rag_endpoints/endpoints.py b/litellm/proxy/rag_endpoints/endpoints.py index 2c34457c02e..39909df8e9f 100644 --- a/litellm/proxy/rag_endpoints/endpoints.py +++ b/litellm/proxy/rag_endpoints/endpoints.py @@ -94,6 +94,7 @@ async def _save_vector_store_to_db_from_rag_ingest( - Checks if the vector store already exists in the database - Creates a new database entry if it doesn't exist - Adds the vector store to the registry + - Tracks team_id and user_id for access control Args: response: The response from litellm.aingest() @@ -176,6 +177,8 @@ async def _save_vector_store_to_db_from_rag_ingest( vector_store_description=vector_store_description, vector_store_metadata=initial_metadata, litellm_params=provider_specific_params if provider_specific_params else None, + team_id=user_api_key_dict.team_id, + user_id=user_api_key_dict.user_id, ) verbose_proxy_logger.info( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index d7aa6e9f0d0..ca60b9e1bec 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -760,6 +760,11 @@ model LiteLLM_ManagedVectorStoresTable { updated_at DateTime @updatedAt litellm_credential_name String? litellm_params Json? + team_id String? + user_id String? + + @@index([team_id]) + @@index([user_id]) } // Guardrails table for storing guardrail configurations diff --git a/litellm/proxy/vector_store_endpoints/endpoints.py b/litellm/proxy/vector_store_endpoints/endpoints.py index 6e22d66c35d..9ba12537bc8 100644 --- a/litellm/proxy/vector_store_endpoints/endpoints.py +++ b/litellm/proxy/vector_store_endpoints/endpoints.py @@ -18,13 +18,54 @@ ######################################################## +def _check_vector_store_access( + vector_store: LiteLLM_ManagedVectorStore, + user_api_key_dict: UserAPIKeyAuth, +) -> bool: + """ + Check if the user has access to the vector store based on team membership. + + Args: + vector_store: The vector store to check access for + user_api_key_dict: User API key authentication info + + Returns: + True if user has access, False otherwise + + Access rules: + - If vector store has no team_id, it's accessible to all (legacy behavior) + - If user's team_id matches the vector store's team_id, access is granted + - Otherwise, access is denied + """ + vector_store_team_id = vector_store.get("team_id") + + # If vector store has no team_id, it's accessible to all (legacy behavior) + if vector_store_team_id is None: + return True + + # Check if user's team matches the vector store's team + user_team_id = user_api_key_dict.team_id + if user_team_id == vector_store_team_id: + return True + + return False + + def _update_request_data_with_litellm_managed_vector_store_registry( data: Dict, vector_store_id: str, + user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> Dict: """ Update the request data with the litellm managed vector store registry. - + + Args: + data: Request data to update + vector_store_id: ID of the vector store + user_api_key_dict: User API key authentication info for access control + + Raises: + HTTPException: If user doesn't have access to the vector store """ if litellm.vector_store_registry is not None: vector_store_to_run: Optional[LiteLLM_ManagedVectorStore] = ( @@ -33,6 +74,14 @@ def _update_request_data_with_litellm_managed_vector_store_registry( ) ) if vector_store_to_run is not None: + # Check access control if user_api_key_dict is provided + if user_api_key_dict is not None: + if not _check_vector_store_access(vector_store_to_run, user_api_key_dict): + raise HTTPException( + status_code=403, + detail="Access denied: You do not have permission to access this vector store", + ) + if "custom_llm_provider" in vector_store_to_run: data["custom_llm_provider"] = vector_store_to_run.get( "custom_llm_provider" @@ -88,7 +137,7 @@ async def vector_store_search( data["vector_store_id"] = vector_store_id data = _update_request_data_with_litellm_managed_vector_store_registry( - data=data, vector_store_id=vector_store_id + data=data, vector_store_id=vector_store_id, user_api_key_dict=user_api_key_dict ) processor = ProxyBaseLLMRequestProcessing(data=data) diff --git a/litellm/proxy/vector_store_endpoints/management_endpoints.py b/litellm/proxy/vector_store_endpoints/management_endpoints.py index f3787e62f4d..6185f1541fc 100644 --- a/litellm/proxy/vector_store_endpoints/management_endpoints.py +++ b/litellm/proxy/vector_store_endpoints/management_endpoints.py @@ -136,6 +136,39 @@ async def _resolve_embedding_config_from_db( ######################################################## # Helper Functions ######################################################## +def _check_vector_store_access( + vector_store: LiteLLM_ManagedVectorStore, + user_api_key_dict: UserAPIKeyAuth, +) -> bool: + """ + Check if the user has access to the vector store based on team membership. + + Args: + vector_store: The vector store to check access for + user_api_key_dict: User API key authentication info + + Returns: + True if user has access, False otherwise + + Access rules: + - If vector store has no team_id, it's accessible to all (legacy behavior) + - If user's team_id matches the vector store's team_id, access is granted + - Otherwise, access is denied + """ + vector_store_team_id = vector_store.get("team_id") + + # If vector store has no team_id, it's accessible to all (legacy behavior) + if vector_store_team_id is None: + return True + + # Check if user's team matches the vector store's team + user_team_id = user_api_key_dict.team_id + if user_team_id == vector_store_team_id: + return True + + return False + + async def create_vector_store_in_db( vector_store_id: str, custom_llm_provider: str, @@ -145,6 +178,8 @@ async def create_vector_store_in_db( vector_store_metadata: Optional[Dict] = None, litellm_params: Optional[Dict] = None, litellm_credential_name: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, ) -> LiteLLM_ManagedVectorStore: """ Helper function to create a vector store in the database. @@ -191,6 +226,10 @@ async def create_vector_store_in_db( data_to_create["vector_store_metadata"] = safe_dumps(vector_store_metadata) if litellm_credential_name is not None: data_to_create["litellm_credential_name"] = litellm_credential_name + if team_id is not None: + data_to_create["team_id"] = team_id + if user_id is not None: + data_to_create["user_id"] = user_id # Handle litellm_params - always provide at least an empty dict if litellm_params: @@ -288,6 +327,8 @@ async def new_vector_store( vector_store_metadata=validated_metadata, litellm_params=vector_store.get("litellm_params"), litellm_credential_name=vector_store.get("litellm_credential_name"), + team_id=user_api_key_dict.team_id, + user_id=user_api_key_dict.user_id, ) return { @@ -380,14 +421,19 @@ async def list_vector_stores( updated_data=vector_store ) - combined_vector_stores = list(vector_store_map.values()) - total_count = len(combined_vector_stores) + # Filter vector stores based on team access + accessible_vector_stores = [ + vs for vs in vector_store_map.values() + if _check_vector_store_access(vs, user_api_key_dict) + ] + + total_count = len(accessible_vector_stores) total_pages = (total_count + page_size - 1) // page_size # Format response using LiteLLM_ManagedVectorStoreListResponse response = LiteLLM_ManagedVectorStoreListResponse( object="list", - data=combined_vector_stores, + data=accessible_vector_stores, total_count=total_count, current_page=page, total_pages=total_pages, @@ -423,6 +469,7 @@ async def delete_vector_store( # Check if vector store exists in database or in-memory registry db_vector_store_exists = False memory_vector_store_exists = False + vector_store_to_check = None existing_vector_store = ( await prisma_client.db.litellm_managedvectorstorestable.find_unique( @@ -431,6 +478,9 @@ async def delete_vector_store( ) if existing_vector_store is not None: db_vector_store_exists = True + vector_store_to_check = LiteLLM_ManagedVectorStore( + **existing_vector_store.model_dump() + ) # Check in-memory registry if litellm.vector_store_registry is not None: @@ -439,6 +489,8 @@ async def delete_vector_store( ) if memory_vector_store is not None: memory_vector_store_exists = True + if vector_store_to_check is None: + vector_store_to_check = memory_vector_store # If not found in either location, raise 404 if not db_vector_store_exists and not memory_vector_store_exists: @@ -446,6 +498,15 @@ async def delete_vector_store( status_code=404, detail=f"Vector store with ID {data.vector_store_id} not found", ) + + # Check access control + if vector_store_to_check and not _check_vector_store_access( + vector_store_to_check, user_api_key_dict + ): + raise HTTPException( + status_code=403, + detail="Access denied: You do not have permission to delete this vector store", + ) # Delete from database if exists if db_vector_store_exists: @@ -492,6 +553,13 @@ async def get_vector_store_info( vector_store_id=data.vector_store_id ) if vector_store is not None: + # Check access control + if not _check_vector_store_access(vector_store, user_api_key_dict): + raise HTTPException( + status_code=403, + detail="Access denied: You do not have permission to access this vector store", + ) + vector_store_metadata = vector_store.get("vector_store_metadata") # Parse metadata if it's a JSON string parsed_metadata: Optional[dict] = None @@ -513,6 +581,8 @@ async def get_vector_store_info( updated_at=vector_store.get("updated_at") or None, litellm_credential_name=vector_store.get("litellm_credential_name"), litellm_params=vector_store.get("litellm_params") or None, + team_id=vector_store.get("team_id") or None, + user_id=vector_store.get("user_id") or None, ) return {"vector_store": vector_store_pydantic_obj} @@ -526,8 +596,16 @@ async def get_vector_store_info( status_code=404, detail=f"Vector store with ID {data.vector_store_id} not found", ) - + + # Check access control for DB vector store vector_store_dict = vector_store.model_dump() # type: ignore[attr-defined] + vector_store_typed = LiteLLM_ManagedVectorStore(**vector_store_dict) + if not _check_vector_store_access(vector_store_typed, user_api_key_dict): + raise HTTPException( + status_code=403, + detail="Access denied: You do not have permission to access this vector store", + ) + return {"vector_store": vector_store_dict} except Exception as e: verbose_proxy_logger.exception(f"Error getting vector store info: {str(e)}") diff --git a/litellm/types/vector_stores.py b/litellm/types/vector_stores.py index a4ceb2c9ac7..58a6a3272cd 100644 --- a/litellm/types/vector_stores.py +++ b/litellm/types/vector_stores.py @@ -38,6 +38,10 @@ class LiteLLM_ManagedVectorStore(TypedDict, total=False): # litellm_params litellm_params: Optional[Dict[str, Any]] + + # access control fields + team_id: Optional[str] + user_id: Optional[str] class LiteLLM_ManagedVectorStoreListResponse(TypedDict, total=False): diff --git a/schema.prisma b/schema.prisma index d7aa6e9f0d0..ca60b9e1bec 100644 --- a/schema.prisma +++ b/schema.prisma @@ -760,6 +760,11 @@ model LiteLLM_ManagedVectorStoresTable { updated_at DateTime @updatedAt litellm_credential_name String? litellm_params Json? + team_id String? + user_id String? + + @@index([team_id]) + @@index([user_id]) } // Guardrails table for storing guardrail configurations diff --git a/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_access_control.py b/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_access_control.py new file mode 100644 index 00000000000..42043c6d168 --- /dev/null +++ b/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_access_control.py @@ -0,0 +1,87 @@ +""" +Test vector store access control based on team membership. + +Core tests: +1. Access control logic works correctly for different team scenarios +2. Delete endpoint enforces team access control +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.vector_store_endpoints.management_endpoints import ( + _check_vector_store_access, +) +from litellm.types.vector_stores import LiteLLM_ManagedVectorStore + + +def test_check_vector_store_access(): + """Test core access control logic for team-based vector store access""" + + # Test 1: Legacy vector stores (no team_id) are accessible to all + vector_store: LiteLLM_ManagedVectorStore = { + "vector_store_id": "vs_legacy", + "custom_llm_provider": "openai", + "team_id": None, + } + user = UserAPIKeyAuth(team_id="team_456") + assert _check_vector_store_access(vector_store, user) is True + + # Test 2: User can access their team's vector stores + vector_store = { + "vector_store_id": "vs_team", + "custom_llm_provider": "openai", + "team_id": "team_456", + } + user = UserAPIKeyAuth(team_id="team_456") + assert _check_vector_store_access(vector_store, user) is True + + # Test 3: User cannot access other teams' vector stores + vector_store = { + "vector_store_id": "vs_team", + "custom_llm_provider": "openai", + "team_id": "team_456", + } + user = UserAPIKeyAuth(team_id="team_789") + assert _check_vector_store_access(vector_store, user) is False + + +@pytest.mark.asyncio +async def test_delete_vector_store_checks_access(): + """Test that delete endpoint enforces team access control""" + from litellm.proxy.vector_store_endpoints.management_endpoints import ( + delete_vector_store, + ) + from litellm.types.vector_stores import VectorStoreDeleteRequest + + mock_prisma = MagicMock() + mock_vector_store = MagicMock( + model_dump=lambda: { + "vector_store_id": "vs_123", + "custom_llm_provider": "openai", + "team_id": "team_456", + } + ) + mock_prisma.db.litellm_managedvectorstorestable.find_unique = AsyncMock( + return_value=mock_vector_store + ) + + # User from different team should get 403 + user_api_key_dict = UserAPIKeyAuth(team_id="team_789") + request = VectorStoreDeleteRequest(vector_store_id="vs_123") + + with patch( + "litellm.proxy.vector_store_endpoints.management_endpoints.prisma_client", + mock_prisma, + ): + with patch("litellm.vector_store_registry", None): + with pytest.raises(HTTPException) as exc_info: + await delete_vector_store( + data=request, user_api_key_dict=user_api_key_dict + ) + + assert exc_info.value.status_code == 403 + assert "Access denied" in exc_info.value.detail