diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index f84a92a64247..663ca2cbe474 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -19,8 +19,8 @@ use goose::config::declarative_providers::{ }; use goose::conversation::message::{ FrontendToolRequest, Message, MessageContent, MessageMetadata, RedactedThinkingContent, - SystemNotificationContent, SystemNotificationType, ThinkingContent, ToolConfirmationRequest, - ToolRequest, ToolResponse, + SystemNotificationContent, SystemNotificationType, ThinkingContent, TokenState, + ToolConfirmationRequest, ToolRequest, ToolResponse, }; use crate::routes::reply::MessageEvent; @@ -402,6 +402,7 @@ derive_utoipa!(Icon as IconSchema); Message, MessageContent, MessageMetadata, + TokenState, ContentSchema, EmbeddedResourceSchema, ImageContentSchema, diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 4c4a81925c57..0748399991be 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -8,7 +8,7 @@ use axum::{ }; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use goose::conversation::message::{Message, MessageContent}; +use goose::conversation::message::{Message, MessageContent, TokenState}; use goose::conversation::Conversation; use goose::permission::{Permission, PermissionConfirmation}; use goose::session::SessionManager; @@ -126,6 +126,7 @@ impl IntoResponse for SseResponse { pub enum MessageEvent { Message { message: Message, + token_state: TokenState, }, Error { error: String, @@ -159,6 +160,7 @@ async fn stream_event( e ) }); + if tx.send(format!("data: {}\n\n", json)).await.is_err() { tracing::info!("client hung up"); cancel_token.cancel(); @@ -305,7 +307,32 @@ pub async fn reply( } all_messages.push(message.clone()); - stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await; + + let token_state = match SessionManager::get_session(&session_id, false).await { + Ok(session) => { + TokenState { + input_tokens: session.input_tokens.unwrap_or(0), + output_tokens: session.output_tokens.unwrap_or(0), + total_tokens: session.total_tokens.unwrap_or(0), + accumulated_input_tokens: session.accumulated_input_tokens.unwrap_or(0), + accumulated_output_tokens: session.accumulated_output_tokens.unwrap_or(0), + accumulated_total_tokens: session.accumulated_total_tokens.unwrap_or(0), + } + }, + Err(e) => { + tracing::warn!("Failed to fetch session for token state: {}", e); + TokenState { + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, + accumulated_input_tokens: 0, + accumulated_output_tokens: 0, + accumulated_total_tokens: 0, + } + } + }; + + stream_event(MessageEvent::Message { message, token_state }, &tx, &cancel_token).await; } Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { all_messages = new_messages.clone(); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 0cf3515929aa..056886f93cf4 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -834,9 +834,11 @@ impl Agent { } } Err(e) => { - yield AgentEvent::Message(Message::assistant().with_text( - format!("Ran into this error trying to compact: {e}.\n\nPlease try again or create a new session") - )); + yield AgentEvent::Message( + Message::assistant().with_text( + format!("Ran into this error trying to compact: {e}.\n\nPlease try again or create a new session") + ) + ); } } })) @@ -926,7 +928,7 @@ impl Agent { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if final_output_tool.final_output.is_some() { let final_event = AgentEvent::Message( - Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()), + Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()) ); yield final_event; break; @@ -935,9 +937,11 @@ impl Agent { turns_taken += 1; if turns_taken > max_turns { - yield AgentEvent::Message(Message::assistant().with_text( - "I've reached the maximum number of actions I can do without user input. Would you like me to continue?" - )); + yield AgentEvent::Message( + Message::assistant().with_text( + "I've reached the maximum number of actions I can do without user input. Would you like me to continue?" + ) + ); break; } @@ -1187,18 +1191,22 @@ impl Agent { } Err(e) => { error!("Error: {}", e); - yield AgentEvent::Message(Message::assistant().with_text( + yield AgentEvent::Message( + Message::assistant().with_text( format!("Ran into this error trying to compact: {e}.\n\nPlease retry if you think this is a transient or recoverable error.") - )); + ) + ); break; } } } Err(e) => { error!("Error: {}", e); - yield AgentEvent::Message(Message::assistant().with_text( + yield AgentEvent::Message( + Message::assistant().with_text( format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.") - )); + ) + ); break; } } @@ -1233,9 +1241,11 @@ impl Agent { } Err(e) => { error!("Retry logic failed: {}", e); - yield AgentEvent::Message(Message::assistant().with_text( - format!("Retry logic encountered an error: {}", e) - )); + yield AgentEvent::Message( + Message::assistant().with_text( + format!("Retry logic encountered an error: {}", e) + ) + ); exit_chat = true; } } diff --git a/crates/goose/src/conversation/message.rs b/crates/goose/src/conversation/message.rs index f432d3292e75..cc7d161dd841 100644 --- a/crates/goose/src/conversation/message.rs +++ b/crates/goose/src/conversation/message.rs @@ -711,6 +711,17 @@ impl Message { } } +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TokenState { + pub input_tokens: i32, + pub output_tokens: i32, + pub total_tokens: i32, + pub accumulated_input_tokens: i32, + pub accumulated_output_tokens: i32, + pub accumulated_total_tokens: i32, +} + #[cfg(test)] mod tests { use crate::conversation::message::{Message, MessageContent, MessageMetadata}; diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index e55983caf669..ae7de6fbfebe 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -278,11 +278,11 @@ impl Add for Usage { type Output = Self; fn add(self, other: Self) -> Self { - Self { - input_tokens: sum_optionals(self.input_tokens, other.input_tokens), - output_tokens: sum_optionals(self.output_tokens, other.output_tokens), - total_tokens: sum_optionals(self.total_tokens, other.total_tokens), - } + Self::new( + sum_optionals(self.input_tokens, other.input_tokens), + sum_optionals(self.output_tokens, other.output_tokens), + sum_optionals(self.total_tokens, other.total_tokens), + ) } } @@ -298,10 +298,21 @@ impl Usage { output_tokens: Option, total_tokens: Option, ) -> Self { + let calculated_total = if total_tokens.is_none() { + match (input_tokens, output_tokens) { + (Some(input), Some(output)) => Some(input + output), + (Some(input), None) => Some(input), + (None, Some(output)) => Some(output), + (None, None) => None, + } + } else { + total_tokens + }; + Self { input_tokens, output_tokens, - total_tokens, + total_tokens: calculated_total, } } } diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index c1d627a6778d..d6488163000b 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -345,11 +345,11 @@ pub fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result { } pub fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { - Usage { - input_tokens: Some(usage.input_tokens), - output_tokens: Some(usage.output_tokens), - total_tokens: Some(usage.total_tokens), - } + Usage::new( + Some(usage.input_tokens), + Some(usage.output_tokens), + Some(usage.total_tokens), + ) } pub fn from_bedrock_json(document: &Document) -> Result { diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 5861b09cdd59..e26c08eaf7b8 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -307,11 +307,11 @@ impl Provider for SageMakerTgiProvider { let message = self.parse_tgi_response(response)?; // TGI doesn't provide usage statistics, so we estimate - let usage = Usage { - input_tokens: Some(0), // Would need to tokenize input to get accurate count - output_tokens: Some(0), // Would need to tokenize output to get accurate count - total_tokens: Some(0), - }; + let usage = Usage::new( + Some(0), // Would need to tokenize input to get accurate count + Some(0), // Would need to tokenize output to get accurate count + Some(0), + ); // Add debug trace let debug_payload = serde_json::json!({ diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 4a699222d4a5..0698fcea175c 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -508,11 +508,11 @@ impl Provider for VeniceProvider { // Extract usage let usage_data = &response_json["usage"]; - let usage = Usage { - input_tokens: usage_data["prompt_tokens"].as_i64().map(|v| v as i32), - output_tokens: usage_data["completion_tokens"].as_i64().map(|v| v as i32), - total_tokens: usage_data["total_tokens"].as_i64().map(|v| v as i32), - }; + let usage = Usage::new( + usage_data["prompt_tokens"].as_i64().map(|v| v as i32), + usage_data["completion_tokens"].as_i64().map(|v| v as i32), + usage_data["total_tokens"].as_i64().map(|v| v as i32), + ); Ok(( Message::new(Role::Assistant, Utc::now().timestamp(), content), diff --git a/openapi.json b/openapi.json deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 83f3a073db9f..2a613451fc6e 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -3185,12 +3185,16 @@ "type": "object", "required": [ "message", + "token_state", "type" ], "properties": { "message": { "$ref": "#/components/schemas/Message" }, + "token_state": { + "$ref": "#/components/schemas/TokenState" + }, "type": { "type": "string", "enum": [ @@ -4418,6 +4422,43 @@ } } }, + "TokenState": { + "type": "object", + "required": [ + "inputTokens", + "outputTokens", + "totalTokens", + "accumulatedInputTokens", + "accumulatedOutputTokens", + "accumulatedTotalTokens" + ], + "properties": { + "accumulatedInputTokens": { + "type": "integer", + "format": "int32" + }, + "accumulatedOutputTokens": { + "type": "integer", + "format": "int32" + }, + "accumulatedTotalTokens": { + "type": "integer", + "format": "int32" + }, + "inputTokens": { + "type": "integer", + "format": "int32" + }, + "outputTokens": { + "type": "integer", + "format": "int32" + }, + "totalTokens": { + "type": "integer", + "format": "int32" + } + } + }, "Tool": { "type": "object", "required": [ diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 6c4eb54cee6c..d9b9c1aac0a7 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -362,6 +362,7 @@ export type MessageContent = (TextContent & { export type MessageEvent = { message: Message; + token_state: TokenState; type: 'Message'; } | { error: string; @@ -779,6 +780,15 @@ export type ThinkingContent = { thinking: string; }; +export type TokenState = { + accumulatedInputTokens: number; + accumulatedOutputTokens: number; + accumulatedTotalTokens: number; + inputTokens: number; + outputTokens: number; + totalTokens: number; +}; + export type Tool = { annotations?: ToolAnnotations | { [key: string]: unknown; diff --git a/ui/desktop/src/components/BaseChat.tsx b/ui/desktop/src/components/BaseChat.tsx index c5a96bcdcb3b..67090686db2c 100644 --- a/ui/desktop/src/components/BaseChat.tsx +++ b/ui/desktop/src/components/BaseChat.tsx @@ -132,6 +132,7 @@ function BaseChatContent({ sessionOutputTokens, localInputTokens, localOutputTokens, + tokenState, commandHistory, toolCallNotifications, sessionMetadata, @@ -442,9 +443,13 @@ function BaseChatContent({ commandHistory={commandHistory} initialValue={input || ''} setView={setView} - numTokens={sessionTokenCount} - inputTokens={sessionInputTokens || localInputTokens} - outputTokens={sessionOutputTokens || localOutputTokens} + totalTokens={tokenState?.totalTokens ?? sessionTokenCount} + accumulatedInputTokens={ + tokenState?.accumulatedInputTokens ?? sessionInputTokens ?? localInputTokens + } + accumulatedOutputTokens={ + tokenState?.accumulatedOutputTokens ?? sessionOutputTokens ?? localOutputTokens + } droppedFiles={droppedFiles} onFilesProcessed={() => setDroppedFiles([])} // Clear dropped files after processing messages={messages} diff --git a/ui/desktop/src/components/BaseChat2.tsx b/ui/desktop/src/components/BaseChat2.tsx index bef13f7860a7..8dfb8f5c2356 100644 --- a/ui/desktop/src/components/BaseChat2.tsx +++ b/ui/desktop/src/components/BaseChat2.tsx @@ -72,6 +72,7 @@ function BaseChatContent({ stopStreaming, sessionLoadError, setRecipeUserParams, + tokenState, } = useChatStream({ sessionId, onStreamFinish, @@ -281,9 +282,13 @@ function BaseChatContent({ //commandHistory={commandHistory} initialValue={initialPrompt} setView={setView} - numTokens={session?.total_tokens || undefined} - inputTokens={session?.input_tokens || undefined} - outputTokens={session?.output_tokens || undefined} + totalTokens={tokenState?.totalTokens ?? session?.total_tokens ?? undefined} + accumulatedInputTokens={ + tokenState?.accumulatedInputTokens ?? session?.accumulated_input_tokens ?? undefined + } + accumulatedOutputTokens={ + tokenState?.accumulatedOutputTokens ?? session?.accumulated_output_tokens ?? undefined + } droppedFiles={droppedFiles} onFilesProcessed={() => setDroppedFiles([])} // Clear dropped files after processing messages={messages} diff --git a/ui/desktop/src/components/ChatInput.tsx b/ui/desktop/src/components/ChatInput.tsx index e907b00ecdbd..863c7f65b5b7 100644 --- a/ui/desktop/src/components/ChatInput.tsx +++ b/ui/desktop/src/components/ChatInput.tsx @@ -70,9 +70,9 @@ interface ChatInputProps { droppedFiles?: DroppedFile[]; onFilesProcessed?: () => void; // Callback to clear dropped files after processing setView: (view: View) => void; - numTokens?: number; - inputTokens?: number; - outputTokens?: number; + totalTokens?: number; + accumulatedInputTokens?: number; + accumulatedOutputTokens?: number; messages?: Message[]; sessionCosts?: { [key: string]: { @@ -103,9 +103,9 @@ export default function ChatInput({ droppedFiles = [], onFilesProcessed, setView, - numTokens, - inputTokens, - outputTokens, + totalTokens, + accumulatedInputTokens, + accumulatedOutputTokens, messages = [], disableAnimation = false, sessionCosts, @@ -505,16 +505,16 @@ export default function ChatInput({ clearAlerts(); // Show alert when either there is registered token usage, or we know the limit - if ((numTokens && numTokens > 0) || (isTokenLimitLoaded && tokenLimit)) { + if ((totalTokens && totalTokens > 0) || (isTokenLimitLoaded && tokenLimit)) { addAlert({ type: AlertType.Info, message: 'Context window', progress: { - current: numTokens || 0, + current: totalTokens || 0, total: tokenLimit, }, showCompactButton: true, - compactButtonDisabled: !numTokens, + compactButtonDisabled: !totalTokens, onCompact: () => { window.dispatchEvent(new CustomEvent('hide-alert-popover')); @@ -542,7 +542,7 @@ export default function ChatInput({ } // We intentionally omit setView as it shouldn't trigger a re-render of alerts // eslint-disable-next-line react-hooks/exhaustive-deps - }, [numTokens, toolCount, tokenLimit, isTokenLimitLoaded, addAlert, clearAlerts]); + }, [totalTokens, toolCount, tokenLimit, isTokenLimitLoaded, addAlert, clearAlerts]); // Cleanup effect for component unmount - prevent memory leaks useEffect(() => { @@ -1540,8 +1540,8 @@ export default function ChatInput({ <>
diff --git a/ui/desktop/src/components/hub.tsx b/ui/desktop/src/components/hub.tsx index 4ae2e64978d9..3cae47ba04dc 100644 --- a/ui/desktop/src/components/hub.tsx +++ b/ui/desktop/src/components/hub.tsx @@ -78,9 +78,9 @@ export default function Hub({ commandHistory={[]} initialValue="" setView={setView} - numTokens={0} - inputTokens={0} - outputTokens={0} + totalTokens={0} + accumulatedInputTokens={0} + accumulatedOutputTokens={0} droppedFiles={[]} onFilesProcessed={() => {}} messages={[]} diff --git a/ui/desktop/src/hooks/useChatEngine.ts b/ui/desktop/src/hooks/useChatEngine.ts index 1c0f6b7e47a4..7c60038e6316 100644 --- a/ui/desktop/src/hooks/useChatEngine.ts +++ b/ui/desktop/src/hooks/useChatEngine.ts @@ -77,6 +77,7 @@ export const useChatEngine = ({ notifications, session, setError, + tokenState, } = useMessageStream({ api: getApiUrl('/reply'), id: chat.sessionId, @@ -451,6 +452,7 @@ export const useChatEngine = ({ sessionOutputTokens, localInputTokens, localOutputTokens, + tokenState, // UI helpers commandHistory, diff --git a/ui/desktop/src/hooks/useChatStream.ts b/ui/desktop/src/hooks/useChatStream.ts index 2b4c9f2b4e02..64ed1912de43 100644 --- a/ui/desktop/src/hooks/useChatStream.ts +++ b/ui/desktop/src/hooks/useChatStream.ts @@ -7,6 +7,7 @@ import { reply, resumeAgent, Session, + TokenState, updateFromSession, updateSessionUserRecipeValues, } from '../api'; @@ -60,6 +61,7 @@ interface UseChatStreamReturn { setRecipeUserParams: (values: Record) => Promise; stopStreaming: () => void; sessionLoadError?: string; + tokenState: TokenState; } function pushMessage(currentMessages: Message[], incomingMsg: Message): Message[] { @@ -88,6 +90,7 @@ async function streamFromResponse( stream: AsyncIterable, initialMessages: Message[], updateMessages: (messages: Message[]) => void, + updateTokenState: (tokenState: TokenState) => void, updateChatState: (state: ChatState) => void, onFinish: (error?: string) => void ): Promise { @@ -119,6 +122,8 @@ async function streamFromResponse( }); } + updateTokenState(event.token_state); + updateMessages(currentMessages); break; } @@ -171,6 +176,14 @@ export function useChatStream({ const [session, setSession] = useState(); const [sessionLoadError, setSessionLoadError] = useState(); const [chatState, setChatState] = useState(ChatState.Idle); + const [tokenState, setTokenState] = useState({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + accumulatedInputTokens: 0, + accumulatedOutputTokens: 0, + accumulatedTotalTokens: 0, + }); const abortControllerRef = useRef(null); useEffect(() => { @@ -288,6 +301,7 @@ export function useChatStream({ stream, currentMessages, (messages: Message[]) => setMessagesAndLog(messages, 'streaming'), + setTokenState, setChatState, onFinish ); @@ -373,5 +387,6 @@ export function useChatStream({ handleSubmit, stopStreaming, setRecipeUserParams, + tokenState, }; } diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts index d05aaadbe067..081d264a7a54 100644 --- a/ui/desktop/src/hooks/useMessageStream.ts +++ b/ui/desktop/src/hooks/useMessageStream.ts @@ -6,7 +6,7 @@ import { getCompactingMessage, hasCompletedToolCalls, } from '../types/message'; -import { Conversation, Message, Role } from '../api'; +import { Conversation, Message, Role, TokenState } from '../api'; import { getSession, Session } from '../api'; import { ChatState } from '../types/chatState'; @@ -35,7 +35,7 @@ export interface NotificationEvent { // Event types for SSE stream type MessageEvent = - | { type: 'Message'; message: Message } + | { type: 'Message'; message: Message; token_state: TokenState } | { type: 'Error'; error: string } | { type: 'Finish'; reason: string } | { type: 'ModelChange'; model: string; mode: string } @@ -165,6 +165,9 @@ export interface UseMessageStreamHelpers { /** Clear error state */ setError: (error: Error | undefined) => void; + + /** Real-time token state from server */ + tokenState: TokenState; } /** @@ -197,6 +200,14 @@ export function useMessageStream({ null ); const [session, setSession] = useState(null); + const [tokenState, setTokenState] = useState({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + accumulatedInputTokens: 0, + accumulatedOutputTokens: 0, + accumulatedTotalTokens: 0, + }); // expose a way to update the body so we can update the session id when CLE occurs const updateMessageStreamBody = useCallback((newBody: object) => { @@ -280,6 +291,8 @@ export function useMessageStream({ // Transition from waiting to streaming on first message mutateChatState(ChatState.Streaming); + setTokenState(parsedEvent.token_state); + // Create a new message object with the properties preserved or defaulted const newMessage: Message = { ...parsedEvent.message, @@ -341,7 +354,6 @@ export function useMessageStream({ } case 'UpdateConversation': { - currentMessages = parsedEvent.conversation; setMessages(parsedEvent.conversation); break; } @@ -650,5 +662,6 @@ export function useMessageStream({ currentModelInfo, session, setError, + tokenState, }; }