diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 6c5425d9a04e..5a340fa3d947 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -133,6 +133,7 @@ pub enum MessageEvent { }, Finish { reason: String, + token_state: TokenState, }, ModelChange { model: String, @@ -149,6 +150,27 @@ pub enum MessageEvent { Ping, } +async fn get_token_state(session_id: &str) -> TokenState { + SessionManager::get_session(session_id, false) + .await + .map(|session| TokenState { + input_tokens: session.input_tokens.unwrap_or(0), + output_tokens: session.output_tokens.unwrap_or(0), + total_tokens: session.total_tokens.unwrap_or(0), + accumulated_input_tokens: session.accumulated_input_tokens.unwrap_or(0), + accumulated_output_tokens: session.accumulated_output_tokens.unwrap_or(0), + accumulated_total_tokens: session.accumulated_total_tokens.unwrap_or(0), + }) + .inspect_err(|e| { + tracing::warn!( + "Failed to fetch session token state for {}: {}", + session_id, + e + ); + }) + .unwrap_or_default() +} + async fn stream_event( event: MessageEvent, tx: &mpsc::Sender, @@ -321,29 +343,7 @@ pub async fn reply( all_messages.push(message.clone()); - let token_state = match SessionManager::get_session(&session_id, false).await { - Ok(session) => { - TokenState { - input_tokens: session.input_tokens.unwrap_or(0), - output_tokens: session.output_tokens.unwrap_or(0), - total_tokens: session.total_tokens.unwrap_or(0), - accumulated_input_tokens: session.accumulated_input_tokens.unwrap_or(0), - accumulated_output_tokens: session.accumulated_output_tokens.unwrap_or(0), - accumulated_total_tokens: session.accumulated_total_tokens.unwrap_or(0), - } - }, - Err(e) => { - tracing::warn!("Failed to fetch session for token state: {}", e); - TokenState { - input_tokens: 0, - output_tokens: 0, - total_tokens: 0, - accumulated_input_tokens: 0, - accumulated_output_tokens: 0, - accumulated_total_tokens: 0, - } - } - }; + let token_state = get_token_state(&session_id).await; stream_event(MessageEvent::Message { message, token_state }, &tx, &cancel_token).await; } @@ -437,9 +437,12 @@ pub async fn reply( ); } + let final_token_state = get_token_state(&session_id).await; + let _ = stream_event( MessageEvent::Finish { reason: "stop".to_string(), + token_state: final_token_state, }, &task_tx, &cancel_token, diff --git a/crates/goose/src/conversation/message.rs b/crates/goose/src/conversation/message.rs index cc7d161dd841..2f18d038836f 100644 --- a/crates/goose/src/conversation/message.rs +++ b/crates/goose/src/conversation/message.rs @@ -711,7 +711,7 @@ impl Message { } } -#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct TokenState { pub input_tokens: i32, diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index d7412d5bf626..7622534994d9 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -3321,12 +3321,16 @@ "type": "object", "required": [ "reason", + "token_state", "type" ], "properties": { "reason": { "type": "string" }, + "token_state": { + "$ref": "#/components/schemas/TokenState" + }, "type": { "type": "string", "enum": [ diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 6468fd3ee3cb..098489445a7b 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -374,6 +374,7 @@ export type MessageEvent = { type: 'Error'; } | { reason: string; + token_state: TokenState; type: 'Finish'; } | { mode: string; diff --git a/ui/desktop/src/components/BaseChat.tsx b/ui/desktop/src/components/BaseChat.tsx index 67090686db2c..fa96740f064c 100644 --- a/ui/desktop/src/components/BaseChat.tsx +++ b/ui/desktop/src/components/BaseChat.tsx @@ -443,12 +443,12 @@ function BaseChatContent({ commandHistory={commandHistory} initialValue={input || ''} setView={setView} - totalTokens={tokenState?.totalTokens ?? sessionTokenCount} + totalTokens={tokenState?.totalTokens || sessionTokenCount} accumulatedInputTokens={ - tokenState?.accumulatedInputTokens ?? sessionInputTokens ?? localInputTokens + tokenState?.accumulatedInputTokens || sessionInputTokens || localInputTokens } accumulatedOutputTokens={ - tokenState?.accumulatedOutputTokens ?? sessionOutputTokens ?? localOutputTokens + tokenState?.accumulatedOutputTokens || sessionOutputTokens || localOutputTokens } droppedFiles={droppedFiles} onFilesProcessed={() => setDroppedFiles([])} // Clear dropped files after processing diff --git a/ui/desktop/src/hooks/useChatEngine.ts b/ui/desktop/src/hooks/useChatEngine.ts index 7c60038e6316..19a333aefda9 100644 --- a/ui/desktop/src/hooks/useChatEngine.ts +++ b/ui/desktop/src/hooks/useChatEngine.ts @@ -213,7 +213,6 @@ export const useChatEngine = ({ // Update token counts when session changes from the message stream useEffect(() => { - console.log('Session received:', session); if (session) { setSessionTokenCount(session.total_tokens || 0); setSessionInputTokens(session.accumulated_input_tokens || 0); diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts index 2eb98835935b..4e36c7d9ec76 100644 --- a/ui/desktop/src/hooks/useMessageStream.ts +++ b/ui/desktop/src/hooks/useMessageStream.ts @@ -37,7 +37,7 @@ export interface NotificationEvent { type MessageEvent = | { type: 'Message'; message: Message; token_state: TokenState } | { type: 'Error'; error: string } - | { type: 'Finish'; reason: string } + | { type: 'Finish'; reason: string; token_state: TokenState } | { type: 'ModelChange'; model: string; mode: string } | { type: 'UpdateConversation'; conversation: Conversation } | NotificationEvent; @@ -368,6 +368,8 @@ export function useMessageStream({ } case 'Finish': { + setTokenState(parsedEvent.token_state); + if (onFinish && currentMessages.length > 0) { const lastMessage = currentMessages[currentMessages.length - 1]; onFinish(lastMessage, parsedEvent.reason);