Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::{Agent, SessionConfig};
use goose::config::Config;
use goose::message::{Message, MessageContent};
use goose::providers::pricing::initialize_pricing_cache;
use goose::session;
use input::InputResult;
use mcp_core::handler::ToolError;
Expand Down Expand Up @@ -1305,11 +1306,40 @@ impl Session {
let model_config = provider.get_model_config();
let context_limit = model_config.context_limit();

let config = Config::global();
let show_cost = config
.get_param::<bool>("GOOSE_CLI_SHOW_COST")
.unwrap_or(false);

let provider_name = config
.get_param::<String>("GOOSE_PROVIDER")
.unwrap_or_else(|_| "unknown".to_string());

// Initialize pricing cache on startup
tracing::info!("Initializing pricing cache...");
if let Err(e) = initialize_pricing_cache().await {
tracing::warn!(
"Failed to initialize pricing cache: {e}. Pricing data may not be available."
);
}

match self.get_metadata() {
Ok(metadata) => {
let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;

output::display_context_usage(total_tokens, context_limit);

if show_cost {
let input_tokens = metadata.input_tokens.unwrap_or(0) as usize;
let output_tokens = metadata.output_tokens.unwrap_or(0) as usize;
output::display_cost_usage(
&provider_name,
&model_config.model_name,
input_tokens,
output_tokens,
)
.await;
}
}
Err(_) => {
output::display_context_usage(0, context_limit);
Expand Down
64 changes: 64 additions & 0 deletions crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use bat::WrappingMode;
use console::{style, Color};
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::providers::pricing::get_model_pricing;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use mcp_core::prompt::PromptArgument;
use mcp_core::tool::ToolCall;
use regex::Regex;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
Expand Down Expand Up @@ -668,6 +670,68 @@ pub fn display_context_usage(total_tokens: usize, context_limit: usize) {
);
}

fn normalize_model_name(model: &str) -> String {
let mut result = model.to_string();

// Remove "-latest" suffix
if result.ends_with("-latest") {
result = result.strip_suffix("-latest").unwrap().to_string();
}

// Remove date-like suffixes: -YYYYMMDD
let re_date = Regex::new(r"-\d{8}$").unwrap();
if re_date.is_match(&result) {
result = re_date.replace(&result, "").to_string();
}

// Convert version numbers like -3-5- to -3.5- (e.g., claude-3-5-haiku -> claude-3.5-haiku)
let re_version = Regex::new(r"-(\d+)-(\d+)-").unwrap();
if re_version.is_match(&result) {
result = re_version.replace(&result, "-$1.$2-").to_string();
}

result
}

async fn estimate_cost_usd(
provider: &str,
model: &str,
input_tokens: usize,
output_tokens: usize,
) -> Option<f64> {
// Use the pricing module's get_model_pricing which handles model name mapping internally
let cleaned_model = normalize_model_name(model);
let pricing_info = get_model_pricing(provider, &cleaned_model).await;

match pricing_info {
Some(pricing) => {
let input_cost = pricing.input_cost * input_tokens as f64;
let output_cost = pricing.output_cost * output_tokens as f64;
Some(input_cost + output_cost)
}
None => None,
}
}

/// Display cost information, if price data is available.
pub async fn display_cost_usage(
provider: &str,
model: &str,
input_tokens: usize,
output_tokens: usize,
) {
if let Some(cost) = estimate_cost_usd(provider, model, input_tokens, output_tokens).await {
use console::style;
println!(
"Cost: {} USD ({} tokens: in {}, out {})",
style(format!("${:.4}", cost)).cyan(),
input_tokens + output_tokens,
input_tokens,
output_tokens
);
}
}

pub struct McpSpinners {
bars: HashMap<String, ProgressBar>,
log_spinner: Option<ProgressBar>,
Expand Down