Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions enterprise/litellm_enterprise/proxy/hooks/managed_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,168 @@ async def afile_list(
"""Handled in files_endpoints.py"""
return []

def _is_batch_polling_enabled(self) -> bool:
    """
    Check if batch cost tracking is actually enabled and running.

    Returns:
        bool: True if the batch cost tracking job is registered on the
        proxy scheduler, False otherwise (or if the check fails for any
        reason — we fail open toward "disabled" so deletion is not
        blocked by a broken check).
    """
    try:
        # NOTE: imported lazily as a deliberate circular-import
        # workaround — litellm.proxy.proxy_server imports this hooks
        # module at startup, so a module-level import would cycle.
        import litellm.proxy.proxy_server as proxy_server_module

        # The scheduler attribute only exists once the proxy has started.
        scheduler = getattr(proxy_server_module, "scheduler", None)
        if scheduler is None:
            return False

        # Check if the check_batch_cost_job exists in the scheduler
        try:
            if scheduler.get_job("check_batch_cost_job") is not None:
                return True
        except Exception:
            # Job not found or scheduler doesn't support get_job
            pass

        return False
    except Exception as e:
        verbose_logger.warning(
            f"Error checking batch polling configuration: {e}. Assuming disabled."
        )
        return False

async def _get_batches_referencing_file(
    self,
    file_id: str,
    model_file_id_mapping: Optional[Dict[str, Dict[str, str]]] = None,
) -> List[Dict[str, Any]]:
    """
    Find batches in non-terminal states that reference this file.

    Non-terminal states: validating, in_progress, finalizing
    Terminal states: completed, complete, failed, expired, cancelled

    Args:
        file_id: The unified file ID to check
        model_file_id_mapping: Optional pre-resolved mapping of
            unified file ID -> {model: provider_file_id}. When supplied,
            the cache/DB lookup via get_model_file_id_mapping is skipped,
            so callers that already resolved the mapping (e.g.
            afile_delete) don't perform the same lookup twice.

    Returns:
        List of batch objects referencing this file in non-terminal state
        (max 10 for error message display)
    """
    # Check both the unified file ID and all provider-specific file IDs.
    file_ids_to_check = [file_id]

    try:
        if model_file_id_mapping is None:
            # Caller did not pre-resolve the mapping — look it up once here.
            model_file_id_mapping = await self.get_model_file_id_mapping(
                [file_id], litellm_parent_otel_span=None
            )

        if model_file_id_mapping and file_id in model_file_id_mapping:
            # Add all provider file IDs for this unified file
            provider_file_ids = list(model_file_id_mapping[file_id].values())
            file_ids_to_check.extend(provider_file_ids)
    except Exception as e:
        verbose_logger.debug(
            f"Could not get model file ID mapping for {file_id}: {e}. "
            f"Will only check unified file ID."
        )

    MAX_MATCHES_TO_RETURN = 10

    # NOTE(review): `take` caps the *candidate* batches fetched (the 10
    # most recent non-terminal ones), not the number of matches — a
    # referencing batch older than those 10 would be missed here.
    # Confirm this trade-off is intentional.
    batches = await self.prisma_client.db.litellm_managedobjecttable.find_many(
        where={
            "file_purpose": "batch",
            "status": {"in": ["validating", "in_progress", "finalizing"]},
        },
        take=MAX_MATCHES_TO_RETURN,
        order={"created_at": "desc"},
    )

    referencing_batches = []
    for batch in batches:
        try:
            # file_object may be stored as a JSON string or already a dict.
            batch_data = (
                json.loads(batch.file_object)
                if isinstance(batch.file_object, str)
                else batch.file_object
            )

            # Batches typically reference the unified file ID in
            # input_file_id; output and error files are generated by
            # the provider.
            referenced_file_ids = [
                fid
                for fid in (
                    batch_data.get("input_file_id"),
                    batch_data.get("output_file_id"),
                    batch_data.get("error_file_id"),
                )
                if fid
            ]

            # Record the batch if it references the file being deleted.
            if any(ref_id in file_ids_to_check for ref_id in referenced_file_ids):
                referencing_batches.append(
                    {
                        "batch_id": batch.unified_object_id,
                        "status": batch.status,
                        "created_at": batch.created_at,
                    }
                )
        except Exception as e:
            # A single malformed batch row must not abort the whole scan.
            verbose_logger.warning(
                f"Error parsing batch object {batch.unified_object_id}: {e}"
            )
            continue

    return referencing_batches

async def _check_file_deletion_allowed(self, file_id: str) -> None:
    """
    Block deletion of a file that active batches still reference.

    Deletion is refused only when both conditions hold:
    1. at least one batch in a non-terminal state references the file, and
    2. batch polling is configured (the user wants cost tracking).

    Args:
        file_id: The unified file ID to check

    Raises:
        HTTPException: If file deletion should be blocked
    """
    # No batch polling means no cost tracking to protect — allow deletion.
    if not self._is_batch_polling_enabled():
        return

    referencing_batches = await self._get_batches_referencing_file(file_id)
    if not referencing_batches:
        # Nothing in a non-terminal state references the file.
        return

    # Limit batches shown in the error message for readability.
    MAX_BATCHES_IN_ERROR = 5

    shown = referencing_batches[:MAX_BATCHES_IN_ERROR]
    batch_statuses = [f"{b['batch_id']}: {b['status']}" for b in shown]

    # _get_batches_referencing_file returns at most 10 matches
    # (MAX_MATCHES_TO_RETURN), so a full page may mean "10 or more".
    if len(referencing_batches) >= 10:
        count_message = "10+"
    else:
        count_message = f"{len(referencing_batches)}"

    error_message = (
        f"Cannot delete file {file_id}. "
        f"The file is referenced by {count_message} batch(es) in non-terminal state"
    )

    # Append batch details; note when the list shown is truncated.
    if len(referencing_batches) <= MAX_BATCHES_IN_ERROR:
        error_message += f": {', '.join(batch_statuses)}. "
    else:
        error_message += (
            f" (showing {MAX_BATCHES_IN_ERROR} most recent): "
            f"{', '.join(batch_statuses)}. "
        )

    error_message += (
        f"To delete this file before complete cost tracking, please delete or cancel the referencing batch(es) first. "
        f"Alternatively, wait for all batches to complete processing."
    )

    raise HTTPException(status_code=400, detail=error_message)

async def afile_delete(
self,
file_id: str,
Expand All @@ -1059,6 +1221,9 @@ async def afile_delete(
**data: Dict,
) -> OpenAIFileObject:

# Check if file deletion should be blocked due to batch references
await self._check_file_deletion_allowed(file_id)

# file_id = convert_b64_uid_to_unified_uid(file_id)
model_file_id_mapping = await self.get_model_file_id_mapping(
[file_id], litellm_parent_otel_span
Expand Down
Loading
Loading