Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions litellm/proxy/management_endpoints/cost_tracking_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
POST /cost/estimate - Estimate cost for a given model and token counts
"""

from typing import Dict, Union
from typing import Dict, Optional, Tuple, Union

from fastapi import APIRouter, Depends, HTTPException

Expand All @@ -29,6 +29,52 @@
router = APIRouter()


def _resolve_model_for_cost_lookup(model: str) -> Tuple[str, Optional[str]]:
    """
    Resolve a model name (which may be a router alias/model_group) to the
    underlying litellm model name for cost lookup.

    Args:
        model: The model name from the request (could be a router alias like 'e-model-router'
            or an actual model name like 'azure_ai/gpt-4')

    Returns:
        Tuple of (resolved_model_name, custom_llm_provider)
        - resolved_model_name: The actual model name to use for cost lookup
        - custom_llm_provider: The provider if resolved from router, None otherwise
    """
    # Imported at call time so this module does not import proxy_server at load
    # time (avoids a circular import between the proxy and its endpoints).
    from litellm.proxy.proxy_server import llm_router

    # Try to resolve from router if available
    if llm_router is not None:
        try:
            # get_model_list handles aliases, wildcards, etc.
            deployments = llm_router.get_model_list(model_name=model)

            # An empty/None deployment list means the router doesn't know this
            # name; fall through to the original model.
            if deployments:
                litellm_params = deployments[0].get("litellm_params", {})
                resolved_model = litellm_params.get("model")

                if resolved_model:
                    verbose_proxy_logger.debug(
                        f"Resolved model '{model}' to '{resolved_model}' from router"
                    )
                    # Provider may or may not be set on the deployment; None is fine.
                    return resolved_model, litellm_params.get("custom_llm_provider")
        except Exception as e:
            # Best-effort lookup: any router error just means we keep the
            # original name instead of failing the cost estimate.
            verbose_proxy_logger.debug(
                f"Could not resolve model '{model}' from router: {e}"
            )

    # Not resolvable via the router — return the original name, no provider.
    return model, None


def _calculate_period_costs(
num_requests, cost_per_request, input_cost, output_cost, margin_cost
):
Expand Down Expand Up @@ -413,12 +459,18 @@ async def estimate_cost(
```
"""
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import Usage
from litellm.utils import ModelResponse
from litellm.types.utils import ModelResponse, Usage

# Resolve model name (handles router aliases like 'e-model-router' -> 'azure_ai/gpt-4')
resolved_model, resolved_provider = _resolve_model_for_cost_lookup(request.model)

verbose_proxy_logger.debug(
f"Cost estimate: request.model='{request.model}' resolved to '{resolved_model}'"
)

# Create a mock response with usage for completion_cost
mock_response = ModelResponse(
model=request.model,
model=resolved_model,
usage=Usage(
prompt_tokens=request.input_tokens,
completion_tokens=request.output_tokens,
Expand All @@ -428,7 +480,7 @@ async def estimate_cost(

# Create a logging object to capture cost breakdown
litellm_logging_obj = LiteLLMLoggingObj(
model=request.model,
model=resolved_model,
messages=[],
stream=False,
call_type="completion",
Expand All @@ -441,14 +493,14 @@ async def estimate_cost(
try:
cost_per_request = completion_cost(
completion_response=mock_response,
model=request.model,
model=resolved_model,
litellm_logging_obj=litellm_logging_obj,
)
except Exception as e:
raise HTTPException(
status_code=404,
detail={
"error": f"Could not calculate cost for model '{request.model}': {str(e)}"
"error": f"Could not calculate cost for model '{request.model}' (resolved to '{resolved_model}'): {str(e)}"
},
)

Expand All @@ -461,7 +513,7 @@ async def estimate_cost(

# Get model info for per-token pricing display
try:
model_info = litellm.get_model_info(model=request.model)
model_info = litellm.get_model_info(model=resolved_model)
input_cost_per_token = model_info.get("input_cost_per_token")
output_cost_per_token = model_info.get("output_cost_per_token")
custom_llm_provider = model_info.get("litellm_provider")
Expand All @@ -470,6 +522,10 @@ async def estimate_cost(
output_cost_per_token = None
custom_llm_provider = None

# Use provider from router resolution if not found in model_info
if custom_llm_provider is None and resolved_provider is not None:
custom_llm_provider = resolved_provider

# Calculate daily and monthly costs
daily_cost, daily_input_cost, daily_output_cost, daily_margin_cost = (
_calculate_period_costs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,67 @@ async def test_estimate_cost_model_not_found(self):
)

assert exc_info.value.status_code == 404

@pytest.mark.asyncio
async def test_estimate_cost_resolves_router_model_alias(self):
"""
Test that estimate_cost resolves router model aliases to underlying models.

When a user selects a model like 'my-gpt4-alias' from the UI (which is a
router model_name), the endpoint should resolve it to the actual model
(e.g., 'azure/gpt-4') for cost calculation.

This prevents the bug where custom model names fail cost lookup because
they aren't in model_prices_and_context_window.json.
"""
request = CostEstimateRequest(
model="my-gpt4-alias", # Router alias, not actual model name
input_tokens=1000,
output_tokens=500,
)

# Mock the router to return deployment info
mock_router = MagicMock()
mock_router.get_model_list.return_value = [
{
"model_name": "my-gpt4-alias",
"litellm_params": {
"model": "azure/gpt-4", # Actual model for pricing
"custom_llm_provider": "azure",
},
}
]

with patch(
"litellm.proxy.proxy_server.llm_router",
mock_router,
):
with patch(
"litellm.proxy.management_endpoints.cost_tracking_settings.completion_cost"
) as mock_completion_cost:
mock_completion_cost.return_value = 0.05

with patch("litellm.get_model_info") as mock_get_model_info:
mock_get_model_info.return_value = {
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006,
"litellm_provider": "azure",
}

response = await estimate_cost(
request=request,
user_api_key_dict=MagicMock(),
)

# Verify router was queried for the alias
mock_router.get_model_list.assert_called_with(model_name="my-gpt4-alias")

# Verify completion_cost was called with RESOLVED model, not the alias
call_args = mock_completion_cost.call_args
assert call_args.kwargs["model"] == "azure/gpt-4"

# Verify response contains original model name (for UI display)
assert response.model == "my-gpt4-alias"
assert response.cost_per_request == 0.05
assert response.provider == "azure"

Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,13 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe

const validEntries = multiResult.entries.filter((e) => e.result !== null);
const loadingEntries = multiResult.entries.filter((e) => e.loading);
const errorEntries = multiResult.entries.filter((e) => e.error !== null);
const hasAnyResult = validEntries.length > 0;
const isAnyLoading = loadingEntries.length > 0;
const hasAnyError = errorEntries.length > 0;

if (!hasAnyResult && !isAnyLoading) {
// Show empty state only if no results, not loading, and no errors
if (!hasAnyResult && !isAnyLoading && !hasAnyError) {
return (
<div className="py-6 text-center border border-dashed border-gray-300 rounded-lg bg-gray-50">
<Text className="text-gray-500">
Expand All @@ -126,7 +129,8 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe
);
}

if (!hasAnyResult && isAnyLoading) {
// Show loading state only if loading and no results/errors yet
if (!hasAnyResult && isAnyLoading && !hasAnyError) {
return (
<div className="py-6 text-center">
<Spin indicator={<LoadingOutlined spin />} />
Expand All @@ -135,6 +139,26 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe
);
}

// Show errors-only view when there are errors but no valid results
if (!hasAnyResult && hasAnyError) {
return (
<div className="space-y-4">
<Divider className="my-4" />
<div className="flex items-center justify-between">
<Text className="text-base font-semibold text-gray-900">Cost Estimates</Text>
{isAnyLoading && <Spin indicator={<LoadingOutlined spin />} size="small" />}
</div>
{/* Error Messages */}
{errorEntries.map((e) => (
<div key={e.entry.id} className="text-sm text-red-600 bg-red-50 p-3 rounded-lg border border-red-200">
<span className="font-medium">{e.entry.model || "Unknown model"}: </span>
{e.error}
</div>
))}
</div>
);
}

const toggleExpanded = (id: string) => {
setExpandedModels((prev) => {
const next = new Set(prev);
Expand All @@ -157,13 +181,28 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe
title: "Model",
dataIndex: "model",
key: "model",
render: (text: string, record: { id: string; provider?: string | null }) => (
<div className="flex items-center gap-2">
<span className="font-medium text-sm">{text}</span>
{record.provider && (
<Tag color="blue" className="text-xs">
{record.provider}
</Tag>
render: (text: string, record: { id: string; provider?: string | null; error?: string | null; loading?: boolean; hasZeroCost?: boolean }) => (
<div className="flex flex-col gap-1">
<div className="flex items-center gap-2">
<span className="font-medium text-sm">{text}</span>
{record.provider && (
<Tag color="blue" className="text-xs">
{record.provider}
</Tag>
)}
{record.loading && (
<Spin indicator={<LoadingOutlined spin />} size="small" />
)}
</div>
{record.error && (
<div className="text-xs text-red-600 bg-red-50 px-2 py-1 rounded">
⚠️ {record.error}
</div>
)}
{record.hasZeroCost && !record.error && (
<div className="text-xs text-amber-600 bg-amber-50 px-2 py-1 rounded">
⚠️ No pricing data found for this model. Set base_model in config.
</div>
)}
</div>
),
Expand All @@ -173,52 +212,65 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe
dataIndex: "cost_per_request",
key: "cost_per_request",
align: "right" as const,
render: (value: number) => <span className="font-mono text-sm">{formatCost(value)}</span>,
render: (value: number | null, record: { error?: string | null }) => (
record.error ? <span className="text-gray-400">-</span> : <span className="font-mono text-sm">{formatCost(value)}</span>
),
},
{
title: "Margin Fee",
dataIndex: "margin_cost_per_request",
key: "margin_cost_per_request",
align: "right" as const,
render: (value: number) => (
<span className={`font-mono text-sm ${value > 0 ? "text-amber-600" : "text-gray-400"}`}>
{formatCost(value)}
</span>
render: (value: number | null, record: { error?: string | null }) => (
record.error ? <span className="text-gray-400">-</span> : (
<span className={`font-mono text-sm ${(value ?? 0) > 0 ? "text-amber-600" : "text-gray-400"}`}>
{formatCost(value)}
</span>
)
),
},
{
title: periodLabel,
dataIndex: periodCostKey,
key: "period_cost",
align: "right" as const,
render: (value: number | null) => <span className="font-mono text-sm">{formatCost(value)}</span>,
render: (value: number | null, record: { error?: string | null }) => (
record.error ? <span className="text-gray-400">-</span> : <span className="font-mono text-sm">{formatCost(value)}</span>
),
},
{
title: "",
key: "expand",
width: 40,
render: (_: unknown, record: { id: string }) => (
<Button
size="xs"
variant="light"
onClick={() => toggleExpanded(record.id)}
className="text-gray-400 hover:text-gray-600"
>
{expandedModels.has(record.id) ? <DownOutlined /> : <RightOutlined />}
</Button>
render: (_: unknown, record: { id: string; error?: string | null }) => (
record.error ? null : (
<Button
size="xs"
variant="light"
onClick={() => toggleExpanded(record.id)}
className="text-gray-400 hover:text-gray-600"
>
{expandedModels.has(record.id) ? <DownOutlined /> : <RightOutlined />}
</Button>
)
),
},
];

const summaryData = validEntries.map((e) => ({
// Include both valid results and errors in the table data
const allEntriesWithModels = multiResult.entries.filter((e) => e.entry.model);
const summaryData = allEntriesWithModels.map((e) => ({
key: e.entry.id,
id: e.entry.id,
model: e.result!.model,
provider: e.result!.provider,
cost_per_request: e.result!.cost_per_request,
margin_cost_per_request: e.result!.margin_cost_per_request,
daily_cost: e.result!.daily_cost,
monthly_cost: e.result!.monthly_cost,
model: e.result?.model || e.entry.model,
provider: e.result?.provider,
cost_per_request: e.result?.cost_per_request ?? null,
margin_cost_per_request: e.result?.margin_cost_per_request ?? null,
daily_cost: e.result?.daily_cost ?? null,
monthly_cost: e.result?.monthly_cost ?? null,
error: e.error,
loading: e.loading,
hasZeroCost: e.result && e.result.cost_per_request === 0,
}));

return (
Expand Down Expand Up @@ -268,7 +320,7 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe
</Card>

{/* Per-Model Table */}
{validEntries.length > 0 && (
{summaryData.length > 0 && (
<Table
columns={summaryColumns}
dataSource={summaryData}
Expand All @@ -290,16 +342,6 @@ const MultiCostResults: React.FC<MultiCostResultsProps> = ({ multiResult, timePe
}}
/>
)}

{/* Error Messages */}
{multiResult.entries
.filter((e) => e.error)
.map((e) => (
<div key={e.entry.id} className="text-sm text-red-600 bg-red-50 p-3 rounded-lg border border-red-200">
<span className="font-medium">{e.entry.model || "Unknown model"}: </span>
{e.error}
</div>
))}
</div>
);
};
Expand Down
Loading