diff --git a/Cargo.lock b/Cargo.lock index 8ce0ef66c877..f3d6daa74423 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2201,7 +2201,7 @@ dependencies = [ [[package]] name = "goose-bench" -version = "1.0.10" +version = "1.0.12" dependencies = [ "anyhow", "async-trait", diff --git a/crates/goose/src/agents/summarize.rs b/crates/goose/src/agents/summarize.rs index aaa41e99de1b..6be0f1d1aeb4 100644 --- a/crates/goose/src/agents/summarize.rs +++ b/crates/goose/src/agents/summarize.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use futures::stream::BoxStream; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::mpsc; use tokio::sync::Mutex; use tracing::{debug, error, instrument, warn}; @@ -20,6 +21,7 @@ use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; use crate::providers::errors::ProviderError; use crate::register_agent; +use crate::session; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; use anyhow::{anyhow, Result}; @@ -164,6 +166,7 @@ impl Agent for SummarizeAgent { async fn reply( &self, messages: &[Message], + session_id: Option, ) -> anyhow::Result>> { let mut messages = messages.to_vec(); let reply_span = tracing::Span::current(); @@ -240,7 +243,19 @@ impl Agent for SummarizeAgent { &tools, ).await { Ok((response, usage)) => { - capabilities.record_usage(usage).await; + capabilities.record_usage(usage.clone()).await; + + // record usage for the session in the session file + if let Some(session_id) = session_id.clone() { + // TODO: track session_id in langfuse tracing + let session_file = session::get_path(session_id); + let mut metadata = session::read_metadata(&session_file)?; + metadata.total_tokens = usage.usage.total_tokens; + // The message count is the number of messages in the session + 1 for the response + // The message count does not include the tool response till next iteration + metadata.message_count = messages.len() + 1; + session::update_metadata(&session_file, &metadata).await?; + } // Reset truncation attempt truncation_attempt = 0; @@ -452,6 +467,11 @@ impl Agent for SummarizeAgent { Err(anyhow!("Prompt '{}' not found", name)) } + + async fn provider(&self) -> Arc> { + let capabilities = self.capabilities.lock().await; + capabilities.provider() + } } register_agent!("summarize", SummarizeAgent); diff --git a/ui/desktop/src/components/settings_v2/providers/ProviderSettingsPage.tsx b/ui/desktop/src/components/settings_v2/providers/ProviderSettingsPage.tsx index d9131f8f39d7..e368e85e28f6 100644 --- a/ui/desktop/src/components/settings_v2/providers/ProviderSettingsPage.tsx +++ b/ui/desktop/src/components/settings_v2/providers/ProviderSettingsPage.tsx @@ -47,6 +47,12 @@ const fakeProviderState: ProviderState[] = [ isConfigured: false, metadata: { location: null }, }, + { + id: 'gcp_vertex_ai', + name: 'GCP Vertex AI', + isConfigured: true, + metadata: { location: null }, + }, ]; export default function ProviderSettings({ onClose }: { onClose: () => void }) { diff --git a/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/ProviderLogo.tsx b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/ProviderLogo.tsx index df947e26c696..d9e1faa884ff 100644 --- a/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/ProviderLogo.tsx +++ b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/ProviderLogo.tsx @@ -6,6 +6,7 @@ import GroqLogo from './icons/groq@3x.png'; import OllamaLogo from './icons/ollama@3x.png'; import DatabricksLogo from './icons/databricks@3x.png'; import OpenRouterLogo from './icons/openrouter@3x.png'; +import DefaultLogo from './icons/default@3x.png'; // Map provider names to their logos const providerLogos = { @@ -16,12 +17,13 @@ const providerLogos = { ollama: OllamaLogo, databricks: DatabricksLogo, openrouter: OpenRouterLogo, + default: DefaultLogo, }; export default function ProviderLogo({ providerName }) { // Convert provider name to lowercase and fetch the logo const logoKey = providerName.toLowerCase(); - const logo = providerLogos[logoKey] || OpenAILogo; // TODO: need default icon + const logo = providerLogos[logoKey] || DefaultLogo; return (
diff --git a/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default.png b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default.png new file mode 100644 index 000000000000..c37e6d0bd1bc Binary files /dev/null and b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default.png differ diff --git a/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default@2x.png b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default@2x.png new file mode 100644 index 000000000000..8825cb0128f2 Binary files /dev/null and b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default@2x.png differ diff --git a/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default@3x.png b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default@3x.png new file mode 100644 index 000000000000..f46b2db4851d Binary files /dev/null and b/ui/desktop/src/components/settings_v2/providers/modal/subcomponents/icons/default@3x.png differ