Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,11 @@ model LiteLLM_ManagedVectorStoresTable {
updated_at DateTime @updatedAt
litellm_credential_name String?
litellm_params Json?
team_id String?
user_id String?

@@index([team_id])
@@index([user_id])
}

// Guardrails table for storing guardrail configurations
Expand Down
5 changes: 5 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ class LiteLLMRoutes(enum.Enum):
"/v1/vector_stores/{vector_store_id}/files/{file_id}",
"/vector_stores/{vector_store_id}/files/{file_id}/content",
"/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
"/vector_store/list",
"/v1/vector_store/list",

# search
"/search",
"/v1/search",
Expand Down Expand Up @@ -3917,6 +3920,8 @@ class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase):
updated_at: Optional[datetime]
litellm_credential_name: Optional[str]
litellm_params: Optional[Dict[str, Any]]
team_id: Optional[str]
user_id: Optional[str]


class ResponseLiteLLM_ManagedVectorStore(TypedDict, total=False):
Expand Down
3 changes: 3 additions & 0 deletions litellm/proxy/rag_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ async def _save_vector_store_to_db_from_rag_ingest(
- Checks if the vector store already exists in the database
- Creates a new database entry if it doesn't exist
- Adds the vector store to the registry
- Tracks team_id and user_id for access control

Args:
response: The response from litellm.aingest()
Expand Down Expand Up @@ -176,6 +177,8 @@ async def _save_vector_store_to_db_from_rag_ingest(
vector_store_description=vector_store_description,
vector_store_metadata=initial_metadata,
litellm_params=provider_specific_params if provider_specific_params else None,
team_id=user_api_key_dict.team_id,
user_id=user_api_key_dict.user_id,
)

verbose_proxy_logger.info(
Expand Down
5 changes: 5 additions & 0 deletions litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,11 @@ model LiteLLM_ManagedVectorStoresTable {
updated_at DateTime @updatedAt
litellm_credential_name String?
litellm_params Json?
team_id String?
user_id String?

@@index([team_id])
@@index([user_id])
}

// Guardrails table for storing guardrail configurations
Expand Down
53 changes: 51 additions & 2 deletions litellm/proxy/vector_store_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,54 @@
########################################################


def _check_vector_store_access(
    vector_store: LiteLLM_ManagedVectorStore,
    user_api_key_dict: UserAPIKeyAuth,
) -> bool:
    """
    Decide whether the caller may use the given vector store.

    Args:
        vector_store: Vector store record; only its "team_id" key is read.
        user_api_key_dict: Caller's API key auth info; only ``team_id`` is read.

    Returns:
        True when access is allowed, False otherwise.

    Access rules:
        - A store whose "team_id" is missing or None is open to everyone
          (legacy behavior).
        - Otherwise the caller's team_id must equal the store's team_id.
    """
    owning_team_id = vector_store.get("team_id")

    if owning_team_id is not None:
        # Team-scoped store: only an exact team match grants access.
        return bool(user_api_key_dict.team_id == owning_team_id)

    # No owning team recorded: keep the store open to all (legacy behavior).
    return True


def _update_request_data_with_litellm_managed_vector_store_registry(
data: Dict,
vector_store_id: str,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> Dict:
"""
Update the request data with the litellm managed vector store registry.


Args:
data: Request data to update
vector_store_id: ID of the vector store
user_api_key_dict: User API key authentication info for access control

Raises:
HTTPException: If user doesn't have access to the vector store
"""
if litellm.vector_store_registry is not None:
vector_store_to_run: Optional[LiteLLM_ManagedVectorStore] = (
Expand All @@ -33,6 +74,14 @@ def _update_request_data_with_litellm_managed_vector_store_registry(
)
)
if vector_store_to_run is not None:
# Check access control if user_api_key_dict is provided
if user_api_key_dict is not None:
if not _check_vector_store_access(vector_store_to_run, user_api_key_dict):
raise HTTPException(
status_code=403,
detail="Access denied: You do not have permission to access this vector store",
)

if "custom_llm_provider" in vector_store_to_run:
data["custom_llm_provider"] = vector_store_to_run.get(
"custom_llm_provider"
Expand Down Expand Up @@ -88,7 +137,7 @@ async def vector_store_search(
data["vector_store_id"] = vector_store_id

data = _update_request_data_with_litellm_managed_vector_store_registry(
data=data, vector_store_id=vector_store_id
data=data, vector_store_id=vector_store_id, user_api_key_dict=user_api_key_dict
)

processor = ProxyBaseLLMRequestProcessing(data=data)
Expand Down
86 changes: 82 additions & 4 deletions litellm/proxy/vector_store_endpoints/management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,39 @@ async def _resolve_embedding_config_from_db(
########################################################
# Helper Functions
########################################################
def _check_vector_store_access(
    vector_store: LiteLLM_ManagedVectorStore,
    user_api_key_dict: UserAPIKeyAuth,
) -> bool:
    """
    Return whether the caller's team is allowed to touch a vector store.

    Args:
        vector_store: The vector store being accessed; its "team_id" entry
            (if any) identifies the owning team.
        user_api_key_dict: Authentication info for the calling key; its
            ``team_id`` attribute is compared against the store's team.

    Returns:
        True if access is granted, False if it is denied.

    Access rules:
        - No "team_id" on the store (missing or None): accessible to all
          (legacy behavior).
        - Store team_id equals the caller's team_id: access granted.
        - Anything else: access denied.
    """
    store_team = vector_store.get("team_id")

    # Guard clause: stores with no team are unrestricted (legacy behavior).
    if store_team is None:
        return True

    caller_team = user_api_key_dict.team_id
    return True if caller_team == store_team else False


async def create_vector_store_in_db(
vector_store_id: str,
custom_llm_provider: str,
Expand All @@ -145,6 +178,8 @@ async def create_vector_store_in_db(
vector_store_metadata: Optional[Dict] = None,
litellm_params: Optional[Dict] = None,
litellm_credential_name: Optional[str] = None,
team_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> LiteLLM_ManagedVectorStore:
"""
Helper function to create a vector store in the database.
Expand Down Expand Up @@ -191,6 +226,10 @@ async def create_vector_store_in_db(
data_to_create["vector_store_metadata"] = safe_dumps(vector_store_metadata)
if litellm_credential_name is not None:
data_to_create["litellm_credential_name"] = litellm_credential_name
if team_id is not None:
data_to_create["team_id"] = team_id
if user_id is not None:
data_to_create["user_id"] = user_id

# Handle litellm_params - always provide at least an empty dict
if litellm_params:
Expand Down Expand Up @@ -288,6 +327,8 @@ async def new_vector_store(
vector_store_metadata=validated_metadata,
litellm_params=vector_store.get("litellm_params"),
litellm_credential_name=vector_store.get("litellm_credential_name"),
team_id=user_api_key_dict.team_id,
user_id=user_api_key_dict.user_id,
)

return {
Expand Down Expand Up @@ -380,14 +421,19 @@ async def list_vector_stores(
updated_data=vector_store
)

combined_vector_stores = list(vector_store_map.values())
total_count = len(combined_vector_stores)
# Filter vector stores based on team access
accessible_vector_stores = [
vs for vs in vector_store_map.values()
if _check_vector_store_access(vs, user_api_key_dict)
]

total_count = len(accessible_vector_stores)
total_pages = (total_count + page_size - 1) // page_size

# Format response using LiteLLM_ManagedVectorStoreListResponse
response = LiteLLM_ManagedVectorStoreListResponse(
object="list",
data=combined_vector_stores,
data=accessible_vector_stores,
total_count=total_count,
current_page=page,
total_pages=total_pages,
Expand Down Expand Up @@ -423,6 +469,7 @@ async def delete_vector_store(
# Check if vector store exists in database or in-memory registry
db_vector_store_exists = False
memory_vector_store_exists = False
vector_store_to_check = None

existing_vector_store = (
await prisma_client.db.litellm_managedvectorstorestable.find_unique(
Expand All @@ -431,6 +478,9 @@ async def delete_vector_store(
)
if existing_vector_store is not None:
db_vector_store_exists = True
vector_store_to_check = LiteLLM_ManagedVectorStore(
**existing_vector_store.model_dump()
)

# Check in-memory registry
if litellm.vector_store_registry is not None:
Expand All @@ -439,13 +489,24 @@ async def delete_vector_store(
)
if memory_vector_store is not None:
memory_vector_store_exists = True
if vector_store_to_check is None:
vector_store_to_check = memory_vector_store

# If not found in either location, raise 404
if not db_vector_store_exists and not memory_vector_store_exists:
raise HTTPException(
status_code=404,
detail=f"Vector store with ID {data.vector_store_id} not found",
)

# Check access control
if vector_store_to_check and not _check_vector_store_access(
vector_store_to_check, user_api_key_dict
):
raise HTTPException(
status_code=403,
detail="Access denied: You do not have permission to delete this vector store",
)

# Delete from database if exists
if db_vector_store_exists:
Expand Down Expand Up @@ -492,6 +553,13 @@ async def get_vector_store_info(
vector_store_id=data.vector_store_id
)
if vector_store is not None:
# Check access control
if not _check_vector_store_access(vector_store, user_api_key_dict):
raise HTTPException(
status_code=403,
detail="Access denied: You do not have permission to access this vector store",
)

vector_store_metadata = vector_store.get("vector_store_metadata")
# Parse metadata if it's a JSON string
parsed_metadata: Optional[dict] = None
Expand All @@ -513,6 +581,8 @@ async def get_vector_store_info(
updated_at=vector_store.get("updated_at") or None,
litellm_credential_name=vector_store.get("litellm_credential_name"),
litellm_params=vector_store.get("litellm_params") or None,
team_id=vector_store.get("team_id") or None,
user_id=vector_store.get("user_id") or None,
)
return {"vector_store": vector_store_pydantic_obj}

Expand All @@ -526,8 +596,16 @@ async def get_vector_store_info(
status_code=404,
detail=f"Vector store with ID {data.vector_store_id} not found",
)


# Check access control for DB vector store
vector_store_dict = vector_store.model_dump() # type: ignore[attr-defined]
vector_store_typed = LiteLLM_ManagedVectorStore(**vector_store_dict)
if not _check_vector_store_access(vector_store_typed, user_api_key_dict):
raise HTTPException(
status_code=403,
detail="Access denied: You do not have permission to access this vector store",
)

return {"vector_store": vector_store_dict}
except Exception as e:
verbose_proxy_logger.exception(f"Error getting vector store info: {str(e)}")
Expand Down
4 changes: 4 additions & 0 deletions litellm/types/vector_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class LiteLLM_ManagedVectorStore(TypedDict, total=False):

# litellm_params
litellm_params: Optional[Dict[str, Any]]

# access control fields
team_id: Optional[str]
user_id: Optional[str]


class LiteLLM_ManagedVectorStoreListResponse(TypedDict, total=False):
Expand Down
5 changes: 5 additions & 0 deletions schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,11 @@ model LiteLLM_ManagedVectorStoresTable {
updated_at DateTime @updatedAt
litellm_credential_name String?
litellm_params Json?
team_id String?
user_id String?

@@index([team_id])
@@index([user_id])
}

// Guardrails table for storing guardrail configurations
Expand Down
Loading
Loading