Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions litellm/proxy/management_endpoints/common_daily_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,45 @@ async def get_api_key_metadata(
prisma_client: PrismaClient,
api_keys: Set[str],
) -> Dict[str, Dict[str, Any]]:
"""Update api key metadata for a single record."""
"""Get api key metadata, falling back to deleted keys table for keys not found in active table.

This ensures that key_alias and team_id are preserved in historical activity logs
even after a key is deleted or regenerated.
"""
key_records = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": list(api_keys)}}
)
return {
k.token: {"key_alias": k.key_alias, "team_id": k.team_id} for k in key_records
result = {
k.token: {"key_alias": k.key_alias, "team_id": k.team_id}
for k in key_records
}

# For any keys not found in the active table, check the deleted keys table
missing_keys = api_keys - set(result.keys())
if missing_keys:
try:
deleted_key_records = (
await prisma_client.db.litellm_deletedverificationtoken.find_many(
where={"token": {"in": list(missing_keys)}},
order={"deleted_at": "desc"},
)
)
# Use the most recent deleted record for each token (ordered by deleted_at desc)
for k in deleted_key_records:
if k.token not in result:
result[k.token] = {
"key_alias": k.key_alias,
"team_id": k.team_id,
}
except Exception as e:
verbose_proxy_logger.warning(
"Failed to fetch deleted key metadata for %d missing keys: %s",
len(missing_keys),
e,
)

return result


def _adjust_dates_for_timezone(
start_date: str,
Expand Down
139 changes: 80 additions & 59 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key
from litellm.proxy.utils import (
PrismaClient,
ProxyLogging,
_hash_token_if_needed,
handle_exception_on_proxy,
is_valid_api_key,
Expand Down Expand Up @@ -3180,6 +3181,63 @@ def get_new_token(data: Optional[RegenerateKeyRequest]) -> str:
return new_token


async def _execute_virtual_key_regeneration(
*,
prisma_client: PrismaClient,
key_in_db: LiteLLM_VerificationToken,
hashed_api_key: str,
key: str,
data: Optional[RegenerateKeyRequest],
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str],
user_api_key_cache: DualCache,
proxy_logging_obj: ProxyLogging,
) -> GenerateKeyResponse:
"""Generate new token, update DB, invalidate cache, and return response."""
from litellm.proxy.proxy_server import hash_token

new_token = get_new_token(data=data)
new_token_hash = hash_token(new_token)
new_token_key_name = f"sk-...{new_token[-4:]}"
update_data = {"token": new_token_hash, "key_name": new_token_key_name}

non_default_values = {}
if data is not None:
non_default_values = await prepare_key_update_data(
data=data, existing_key_row=key_in_db
)
verbose_proxy_logger.debug("non_default_values: %s", non_default_values)
update_data.update(non_default_values)
update_data = prisma_client.jsonify_object(data=update_data)

updated_token = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_api_key},
data=update_data, # type: ignore
)
updated_token_dict = dict(updated_token) if updated_token is not None else {}
updated_token_dict["key"] = new_token
updated_token_dict["token_id"] = updated_token_dict.pop("token")

if hashed_api_key or key:
await _delete_cache_key_object(
hashed_token=hash_token(key),
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)

response = GenerateKeyResponse(**updated_token_dict)
asyncio.create_task(
KeyManagementEventHooks.async_key_rotated_hook(
data=data,
existing_key_row=key_in_db,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
)
return response


@router.post(
"/key/{key:path}/regenerate",
tags=["key management"],
Expand Down Expand Up @@ -3348,69 +3406,32 @@ async def regenerate_key_fn( # noqa: PLR0915
)
verbose_proxy_logger.debug("key_in_db: %s", _key_in_db)

new_token = get_new_token(data=data)

new_token_hash = hash_token(new_token)
new_token_key_name = f"sk-...{new_token[-4:]}"

# Prepare the update data
update_data = {
"token": new_token_hash,
"key_name": new_token_key_name,
}

non_default_values = {}
if data is not None:
# Update with any provided parameters from GenerateKeyRequest
non_default_values = await prepare_key_update_data(
data=data, existing_key_row=_key_in_db
)
verbose_proxy_logger.debug("non_default_values: %s", non_default_values)

update_data.update(non_default_values)
update_data = prisma_client.jsonify_object(data=update_data)
# Update the token in the database
updated_token = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_api_key},
data=update_data, # type: ignore
)

updated_token_dict = {}
if updated_token is not None:
updated_token_dict = dict(updated_token)

updated_token_dict["key"] = new_token
updated_token_dict["token_id"] = updated_token_dict.pop("token")

### 3. remove existing key entry from cache
######################################################################

if hashed_api_key or key:
await _delete_cache_key_object(
hashed_token=hash_token(key),
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# Normalize litellm_changed_by: if it's a Header object or not a string, convert to None
if litellm_changed_by is not None and not isinstance(litellm_changed_by, str):
litellm_changed_by = None

response = GenerateKeyResponse(
**updated_token_dict,
# Save the old key record to deleted table before regeneration.
# This preserves key_alias and team_id metadata for historical spend records.
# If this fails, abort the regeneration to avoid permanently losing the
# old hash→metadata mapping.
await _persist_deleted_verification_tokens(
keys=[_key_in_db],
prisma_client=prisma_client,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)

verbose_proxy_logger.info(
"Key regeneration completed: key_alias=%s",
getattr(_key_in_db, "key_alias", None),
)
asyncio.create_task(
KeyManagementEventHooks.async_key_rotated_hook(
data=data,
existing_key_row=_key_in_db,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
return await _execute_virtual_key_regeneration(
prisma_client=prisma_client,
key_in_db=_key_in_db,
hashed_api_key=hashed_api_key,
key=key,
data=data,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)

return response
except Exception as e:
verbose_proxy_logger.exception("Error regenerating key: %s", e)
raise handle_exception_on_proxy(e)
Expand Down
Loading
Loading