Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class KeyManagementRoutes(str, enum.Enum):
KEY_BLOCK = "/key/block"
KEY_UNBLOCK = "/key/unblock"
KEY_BULK_UPDATE = "/key/bulk_update"
KEY_RESET_SPEND = "/key/{key_id}/reset_spend"

# info and health routes
KEY_INFO = "/key/info"
Expand Down Expand Up @@ -987,6 +988,10 @@ class RegenerateKeyRequest(GenerateKeyRequest):
new_master_key: Optional[str] = None


class ResetSpendRequest(LiteLLMPydanticObjectBase):
reset_to: float


class KeyRequest(LiteLLMPydanticObjectBase):
keys: Optional[List[str]] = None
key_aliases: Optional[List[str]] = None
Expand Down
157 changes: 157 additions & 0 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3373,6 +3373,163 @@ async def regenerate_key_fn(
raise handle_exception_on_proxy(e)


async def _check_proxy_or_team_admin_for_key(
key_in_db: LiteLLM_VerificationToken,
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
) -> None:
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return

if key_in_db.team_id is not None:
team_table = await get_team_object(
team_id=key_in_db.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
check_db_only=True,
)
if team_table is not None:
if _is_user_team_admin(
user_api_key_dict=user_api_key_dict,
team_obj=team_table,
):
return

raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={"error": "You must be a proxy admin or team admin to reset key spend"},
)


def _validate_reset_spend_value(
reset_to: Any, key_in_db: LiteLLM_VerificationToken
) -> float:
if not isinstance(reset_to, (int, float)):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "reset_to must be a float"},
)

reset_to = float(reset_to)

if reset_to < 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "reset_to must be >= 0"},
)

current_spend = key_in_db.spend or 0.0
if reset_to > current_spend:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": f"reset_to ({reset_to}) must be <= current spend ({current_spend})"},
)

max_budget = key_in_db.max_budget
if key_in_db.litellm_budget_table is not None:
budget_max_budget = getattr(key_in_db.litellm_budget_table, "max_budget", None)
if budget_max_budget is not None:
if max_budget is None or budget_max_budget < max_budget:
max_budget = budget_max_budget

if max_budget is not None and reset_to > max_budget:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": f"reset_to ({reset_to}) must be <= budget ({max_budget})"},
)

return reset_to


@router.post(
"/key/{key:path}/reset_spend",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def reset_key_spend_fn(
key: str,
data: ResetSpendRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
None,
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
),
) -> Dict[str, Any]:
try:
from litellm.proxy.proxy_server import (
hash_token,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)

if prisma_client is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": "DB not connected. prisma_client is None"},
)

if "sk" not in key:
hashed_api_key = key
else:
hashed_api_key = hash_token(key)

_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_api_key},
include={"litellm_budget_table": True},
)
if _key_in_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"error": f"Key {key} not found."},
)

current_spend = _key_in_db.spend or 0.0
reset_to = _validate_reset_spend_value(data.reset_to, _key_in_db)

await _check_proxy_or_team_admin_for_key(
key_in_db=_key_in_db,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)

updated_key = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_api_key},
data={"spend": reset_to},
)

if updated_key is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": "Failed to update key spend"},
)

await _delete_cache_key_object(
hashed_token=hashed_api_key,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)

max_budget = updated_key.max_budget
budget_reset_at = updated_key.budget_reset_at

return {
"key_hash": hashed_api_key,
"spend": reset_to,
"previous_spend": current_spend,
"max_budget": max_budget,
"budget_reset_at": budget_reset_at,
}
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.exception("Error resetting key spend: %s", e)
raise handle_exception_on_proxy(e)


async def validate_key_list_check(
user_api_key_dict: UserAPIKeyAuth,
user_id: Optional[str],
Expand Down
Loading
Loading