Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/goose-cli/src/commands/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,10 @@ async fn process_message_streaming(
// For now, we'll just log them
tracing::info!("Received MCP notification in web interface");
}
Ok(AgentEvent::ModelChange { model, mode }) => {
// Log model change
tracing::info!("Model changed to {} in {} mode", model, mode);
}
Err(e) => {
error!("Error in message stream: {}", e);
let mut sender = sender.lock().await;
Expand Down
6 changes: 6 additions & 0 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,12 @@ impl Session {
}
}
}
Some(Ok(AgentEvent::ModelChange { model, mode })) => {
// Log model change if in debug mode
if self.debug {
eprintln!("Model changed to {} in {} mode", model, mode);
}
}
Some(Err(e)) => {
eprintln!("Error: {}", e);
drop(stream);
Expand Down
3 changes: 3 additions & 0 deletions crates/goose-ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ pub unsafe extern "C" fn goose_agent_send_message(
Ok(AgentEvent::McpNotification(_)) => {
// TODO: Handle MCP notifications.
}
Ok(AgentEvent::ModelChange { .. }) => {
// Model change events are informational, just continue
}
Err(e) => {
full_response.push_str(&format!("\nError in message stream: {}", e));
}
Expand Down
3 changes: 3 additions & 0 deletions crates/goose-scheduler-executor/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ async fn execute_recipe(job_id: &str, recipe_path: &str) -> Result<String> {
Ok(AgentEvent::McpNotification(_)) => {
// Handle notifications if needed
}
Ok(AgentEvent::ModelChange { .. }) => {
// Model change events are informational, just continue
}
Err(e) => {
return Err(anyhow!("Error receiving message from agent: {}", e));
}
Expand Down
21 changes: 21 additions & 0 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,26 @@ pub async fn backup_config(
}
}

#[utoipa::path(
get,
path = "/config/current-model",
responses(
(status = 200, description = "Current model retrieved successfully", body = String),
)
)]
pub async fn get_current_model(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;

let current_model = goose::providers::base::get_current_model();

Ok(Json(serde_json::json!({
"model": current_model
})))
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/config", get(read_all_config))
Expand All @@ -454,6 +474,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/config/init", post(init_config))
.route("/config/backup", post(backup_config))
.route("/config/permissions", post(upsert_permissions))
.route("/config/current-model", get(get_current_model))
.with_state(state)
}

Expand Down
19 changes: 19 additions & 0 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ enum MessageEvent {
Finish {
reason: String,
},
ModelChange {
model: String,
mode: String,
},
Notification {
request_id: String,
message: JsonRpcMessage,
Expand Down Expand Up @@ -233,6 +237,17 @@ async fn handler(
}
});
}
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
tracing::error!("Error sending model change through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
if let Err(e) = stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
Expand Down Expand Up @@ -352,6 +367,10 @@ async fn ask_handler(
}
}
}
Ok(AgentEvent::ModelChange { model, mode }) => {
// Log model change for non-streaming
tracing::info!("Model changed to {} in {} mode", model, mode);
}
Ok(AgentEvent::McpNotification(n)) => {
// Handle notifications if needed
tracing::info!("Received notification: {:?}", n);
Expand Down
21 changes: 21 additions & 0 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub struct Agent {
pub enum AgentEvent {
Message(Message),
McpNotification((String, JsonRpcMessage)),
ModelChange { model: String, mode: String },
}

impl Agent {
Expand Down Expand Up @@ -582,6 +583,26 @@ impl Agent {
&toolshim_tools,
).await {
Ok((response, usage)) => {
// Emit model change event if provider is lead-worker
let provider = self.provider().await?;
if let Some(lead_worker) = provider.as_lead_worker() {
// The actual model used is in the usage
let active_model = usage.model.clone();
let (lead_model, worker_model) = lead_worker.get_model_info();
let mode = if active_model == lead_model {
"lead"
} else if active_model == worker_model {
"worker"
} else {
"unknown"
};

yield AgentEvent::ModelChange {
model: active_model,
mode: mode.to_string(),
};
}

// record usage for the session in the session file
if let Some(session_config) = session.clone() {
Self::update_session_metrics(session_config, &usage, messages.len()).await?;
Expand Down
14 changes: 14 additions & 0 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ use async_trait::async_trait;
pub trait LeadWorkerProviderTrait {
/// Get information about the lead and worker models for logging
fn get_model_info(&self) -> (String, String);

/// Get the currently active model name
fn get_active_model(&self) -> String;
}

/// Base trait for AI providers (OpenAI, Anthropic, etc)
Expand Down Expand Up @@ -207,6 +210,17 @@ pub trait Provider: Send + Sync {
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
None
}

/// Get the currently active model name
/// For regular providers, this returns the configured model
/// For LeadWorkerProvider, this returns the currently active model (lead or worker)
fn get_active_model_name(&self) -> String {
if let Some(lead_worker) = self.as_lead_worker() {
lead_worker.get_active_model()
} else {
self.get_model_config().model_name
}
}
}

#[cfg(test)]
Expand Down
30 changes: 26 additions & 4 deletions crates/goose/src/providers/lead_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@ impl LeadWorkerProviderTrait for LeadWorkerProvider {
let worker_model = self.worker_provider.get_model_config().model_name;
(lead_model, worker_model)
}

/// Get the currently active model name
fn get_active_model(&self) -> String {
// Read from the global store which was set during complete()
use super::base::get_current_model;
get_current_model().unwrap_or_else(|| {
// Fallback to lead model if no current model is set
self.lead_provider.get_model_config().model_name
})
}
}

#[async_trait]
Expand Down Expand Up @@ -336,19 +346,31 @@ impl Provider for LeadWorkerProvider {
"worker"
};

// Get the active model name and update the global store
let active_model_name = if turn_count < self.lead_turns || in_fallback {
self.lead_provider.get_model_config().model_name.clone()
} else {
self.worker_provider.get_model_config().model_name.clone()
};

// Update the global current model store
super::base::set_current_model(&active_model_name);

if in_fallback {
tracing::info!(
"🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
"🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining) - Model: {}",
provider_type,
turn_count + 1,
fallback_remaining
fallback_remaining,
active_model_name
);
} else {
tracing::info!(
"Using {} provider for turn {} (lead_turns: {})",
"Using {} provider for turn {} (lead_turns: {}) - Model: {}",
provider_type,
turn_count + 1,
self.lead_turns
self.lead_turns,
active_model_name
);
}

Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,9 @@ async fn run_scheduled_job_internal(
Ok(AgentEvent::McpNotification(_)) => {
// Handle notifications if needed
}
Ok(AgentEvent::ModelChange { .. }) => {
// Model change events are informational, just continue
}
Err(e) => {
tracing::error!(
"[Job {}] Error receiving message from agent: {}",
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/tests/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ async fn run_truncate_test(
Ok(AgentEvent::McpNotification(n)) => {
println!("MCP Notification: {n:?}");
}
Ok(AgentEvent::ModelChange { .. }) => {
// Model change events are informational, just continue
}
Err(e) => {
println!("Error: {:?}", e);
return Err(e);
Expand Down
11 changes: 9 additions & 2 deletions ui/desktop/src/components/ChatView.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useEffect, useRef, useState, useMemo, useCallback } from 'react';
import React, { useEffect, useRef, useState, useMemo, useCallback, createContext, useContext } from 'react';
import { getApiUrl } from '../config';
import FlappyGoose from './FlappyGoose';
import GooseMessage from './GooseMessage';
Expand Down Expand Up @@ -37,6 +37,10 @@ import {
TextContent,
} from '../types/message';

// Context for sharing current model info
const CurrentModelContext = createContext<{ model: string; mode: string } | null>(null);
export const useCurrentModelInfo = () => useContext(CurrentModelContext);

export interface ChatType {
id: string;
title: string;
Expand Down Expand Up @@ -144,6 +148,7 @@ function ChatContent({
handleSubmit: _submitMessage,
updateMessageStreamBody,
notifications,
currentModelInfo,
} = useMessageStream({
api: getApiUrl('/reply'),
initialMessages: chat.messages,
Expand Down Expand Up @@ -504,7 +509,8 @@ function ChatContent({
}, new Map());

return (
<div className="flex flex-col w-full h-screen items-center justify-center">
<CurrentModelContext.Provider value={currentModelInfo}>
<div className="flex flex-col w-full h-screen items-center justify-center">
{/* Loader when generating recipe */}
{isGeneratingRecipe && <LayingEggLoader />}
<MoreMenuLayout
Expand Down Expand Up @@ -647,5 +653,6 @@ function ChatContent({
summaryContent={summaryContent}
/>
</div>
</CurrentModelContext.Provider>
);
}
Loading
Loading