diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 9c71d1286aa1..8ad2b0422352 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -648,18 +648,8 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { process::exit(1); } }; - let provider_for_display = Arc::clone(&new_provider); - - if let Some(lead_worker) = new_provider.as_lead_worker() { - let (lead_model, worker_model) = lead_worker.get_model_info(); - tracing::info!( - "🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled", - lead_model, - worker_model - ); - } else { - tracing::info!("🤖 Using model: {}", resolved.model_name); - } + let provider_for_debug = Arc::clone(&new_provider); + tracing::info!("🤖 Using model: {}", resolved.model_name); agent .update_provider(new_provider, &session_id) @@ -692,7 +682,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { let agent_ptr = resolve_and_load_extensions( agent, extensions_for_provider, - Arc::clone(&provider_for_display), + Arc::clone(&provider_for_debug), session_config.interactive, &session_id, ) @@ -732,7 +722,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { &resolved.provider_name, &resolved.model_name, &Some(session_id), - Some(&provider_for_display), ); } session diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 416adec726ea..bf655a698483 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -79,10 +79,6 @@ enum StreamEvent { #[serde(flatten)] data: NotificationData, }, - ModelChange { - model: String, - mode: String, - }, Error { error: String, }, @@ -1070,13 +1066,6 @@ impl CliSession { Some(Ok(AgentEvent::HistoryReplaced(updated_conversation))) => { self.messages = updated_conversation; } - Some(Ok(AgentEvent::ModelChange { model, mode })) => { - if is_stream_json_mode { - emit_stream_event(&StreamEvent::ModelChange { model: model.clone(), mode: mode.clone() }); - } else if self.debug { - eprintln!("Model changed to {} in {} mode", model, mode); - } - } Some(Err(e)) => { handle_agent_error(&e, is_stream_json_mode); cancel_token_clone.cancel(); diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 8472cc8a6af2..ea3f0678d4ab 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -17,7 +17,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::io::{Error, IsTerminal, Write}; use std::path::Path; -use std::sync::{Arc, LazyLock}; +use std::sync::LazyLock; use std::time::Duration; use super::streaming_buffer::MarkdownBuffer; @@ -1251,7 +1251,6 @@ pub fn display_session_info( provider: &str, model: &str, session_id: &Option, - provider_instance: Option<&Arc>, ) { set_terminal_title(); @@ -1263,16 +1262,7 @@ pub fn display_session_info( "new session" }; - let model_display = if let Some(provider_inst) = provider_instance { - if let Some(lead_worker) = provider_inst.as_lead_worker() { - let (lead_model, worker_model) = lead_worker.get_model_info(); - format!("{} → {}", lead_model, worker_model) - } else { - model.to_string() - } - } else { - model.to_string() - }; + let model_display = model.to_string(); let cwd_display = std::env::current_dir() .ok() diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 7df27b434c43..09fc5b8b932b 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -137,10 +137,6 @@ pub enum MessageEvent { reason: String, token_state: TokenState, }, - ModelChange { - model: String, - mode: String, - }, Notification { request_id: String, #[schema(value_type = Object)] @@ -364,9 +360,6 @@ pub async fn reply( stream_event(MessageEvent::UpdateConversation {conversation: new_messages}, &tx, &cancel_token).await; } - Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { - stream_event(MessageEvent::ModelChange { model, mode }, &tx, &cancel_token).await; - } Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { stream_event(MessageEvent::Notification{ request_id: request_id.clone(), diff --git a/crates/goose-server/src/routes/session_events.rs b/crates/goose-server/src/routes/session_events.rs index 7f69aa97ed69..065bd1f3c24b 100644 --- a/crates/goose-server/src/routes/session_events.rs +++ b/crates/goose-server/src/routes/session_events.rs @@ -462,13 +462,6 @@ pub async fn session_reply( ) .await; } - Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { - publish( - Some(task_request_id.clone()), - MessageEvent::ModelChange { model, mode }, - ) - .await; - } Ok(Some(Ok(AgentEvent::McpNotification((notification_request_id, n))))) => { publish( Some(task_request_id.clone()), diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index da4f9b00e15e..3f3295016695 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -158,7 +158,6 @@ pub struct Agent { pub enum AgentEvent { Message(Message), McpNotification((String, ServerNotification)), - ModelChange { model: String, mode: String }, HistoryReplaced(Conversation), } @@ -1230,27 +1229,6 @@ impl Agent { Ok((response, usage)) => { compaction_attempts = 0; - // Emit model change event if provider is lead-worker - let provider = self.provider().await?; - if let Some(lead_worker) = provider.as_lead_worker() { - if let Some(ref usage) = usage { - let active_model = usage.model.clone(); - let (lead_model, worker_model) = lead_worker.get_model_info(); - let mode = if active_model == lead_model { - "lead" - } else if active_model == worker_model { - "worker" - } else { - "unknown" - }; - - yield AgentEvent::ModelChange { - model: active_model, - mode: mode.to_string(), - }; - } - } - if let Some(ref usage) = usage { self.update_session_metrics(&session_config.id, session_config.schedule_id.clone(), usage, false).await?; } diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 0fc3cf6c86d9..e40f649ea7d7 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -208,8 +208,6 @@ impl Agent { Ok((tools, toolshim_tools, system_prompt)) } - // Don't add gen_ai.request.model here — provider.get_model_config() - // returns the wrong model for LeadWorkerProvider. #[tracing::instrument( skip(provider, session_id, system_prompt, messages, tools, toolshim_tools), fields(session.id = %session_id) diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index e1ca3fc888fe..ade7d0b226d1 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -208,7 +208,7 @@ fn get_agent_messages(params: SubagentRunParams) -> AgentMessagesFuture { } conversation.push(msg); } - Ok(AgentEvent::McpNotification(_)) | Ok(AgentEvent::ModelChange { .. }) => {} + Ok(AgentEvent::McpNotification(_)) => {} Ok(AgentEvent::HistoryReplaced(updated_conversation)) => { conversation = updated_conversation; } diff --git a/crates/goose/src/gateway/handler.rs b/crates/goose/src/gateway/handler.rs index 3825337470ec..0e682fece160 100644 --- a/crates/goose/src/gateway/handler.rs +++ b/crates/goose/src/gateway/handler.rs @@ -445,17 +445,6 @@ impl GatewayHandler { "gateway stream: mcp notification #{event_count}" ); } - Ok(AgentEvent::ModelChange { - ref model, - ref mode, - }) => { - tracing::debug!( - session_id, - model, - mode, - "gateway stream: model change #{event_count}" - ); - } Ok(AgentEvent::HistoryReplaced(_)) => { tracing::debug!( session_id, diff --git a/crates/goose/src/posthog.rs b/crates/goose/src/posthog.rs index b34945c16862..6f54c26e7ba2 100644 --- a/crates/goose/src/posthog.rs +++ b/crates/goose/src/posthog.rs @@ -421,30 +421,6 @@ async fn send_session_event(installation: &InstallationData) -> Result<(), Strin insert(&mut props, "setting_max_turns", max_turns); } - if let Ok(lead_model) = config.get_param::("GOOSE_LEAD_MODEL") { - insert(&mut props, "setting_lead_model", lead_model); - } - if let Ok(lead_provider) = config.get_param::("GOOSE_LEAD_PROVIDER") { - insert(&mut props, "setting_lead_provider", lead_provider); - } - if let Ok(lead_turns) = config.get_param::("GOOSE_LEAD_TURNS") { - insert(&mut props, "setting_lead_turns", lead_turns); - } - if let Ok(lead_failure_threshold) = config.get_param::("GOOSE_LEAD_FAILURE_THRESHOLD") { - insert( - &mut props, - "setting_lead_failure_threshold", - lead_failure_threshold, - ); - } - if let Ok(lead_fallback_turns) = config.get_param::("GOOSE_LEAD_FALLBACK_TURNS") { - insert( - &mut props, - "setting_lead_fallback_turns", - lead_fallback_turns, - ); - } - let extensions = get_enabled_extensions(); insert(&mut props, "extensions_count", extensions.len() as u64); let extension_names: Vec = extensions.iter().map(|e| e.name()).collect(); diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 46fd49bc8142..fae0dd8b7ae1 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -450,18 +450,6 @@ pub enum PermissionRouting { Noop, } -/// Trait for LeadWorkerProvider-specific functionality -pub trait LeadWorkerProviderTrait { - /// Get information about the lead and worker models for logging - fn get_model_info(&self) -> (String, String); - - /// Get the currently active model name - fn get_active_model(&self) -> String; - - /// Get (lead_turns, failure_threshold, fallback_turns) - fn get_settings(&self) -> (usize, usize, usize); -} - /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] pub trait Provider: Send + Sync { @@ -645,23 +633,6 @@ pub trait Provider: Send + Sync { )) } - /// Check if this provider is a LeadWorkerProvider - /// This is used for logging model information at startup - fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> { - None - } - - /// Get the currently active model name - /// For regular providers, this returns the configured model - /// For LeadWorkerProvider, this returns the currently active model (lead or worker) - fn get_active_model_name(&self) -> String { - if let Some(lead_worker) = self.as_lead_worker() { - lead_worker.get_active_model() - } else { - self.get_model_config().model_name - } - } - /// Returns the first 3 user messages as strings for session naming fn get_initial_user_messages(&self, messages: &Conversation) -> Vec { messages diff --git a/crates/goose/src/providers/init.rs b/crates/goose/src/providers/init.rs index 03cef75f5e5b..1d00a9c9e09c 100644 --- a/crates/goose/src/providers/init.rs +++ b/crates/goose/src/providers/init.rs @@ -17,7 +17,6 @@ use super::{ gemini_cli::GeminiCliProvider, githubcopilot::GithubCopilotProvider, google::GoogleProvider, - lead_worker::LeadWorkerProvider, litellm::LiteLLMProvider, local_inference::LocalInferenceProvider, nanogpt::NanoGptProvider, @@ -41,10 +40,6 @@ use crate::{ use anyhow::Result; use tokio::sync::OnceCell; -const DEFAULT_LEAD_TURNS: usize = 3; -const DEFAULT_FAILURE_THRESHOLD: usize = 2; -const DEFAULT_FALLBACK_TURNS: usize = 2; - static REGISTRY: OnceCell> = OnceCell::const_new(); async fn init_registry() -> RwLock { @@ -135,13 +130,6 @@ pub async fn create( model: ModelConfig, extensions: Vec, ) -> Result> { - let config = crate::config::Config::global(); - - if let Ok(lead_model_name) = config.get_param::("GOOSE_LEAD_MODEL") { - tracing::info!("Creating lead/worker provider from environment variables"); - return create_lead_worker_from_env(name, &model, &lead_model_name, extensions).await; - } - let constructor = get_from_registry(name).await?.constructor.clone(); constructor(model, extensions).await } @@ -179,179 +167,10 @@ pub async fn create_with_named_model( create(provider_name, config, extensions).await } -async fn create_lead_worker_from_env( - default_provider_name: &str, - default_model: &ModelConfig, - lead_model_name: &str, - extensions: Vec, -) -> Result> { - let config = crate::config::Config::global(); - - let lead_provider_name = config - .get_param::("GOOSE_LEAD_PROVIDER") - .unwrap_or_else(|_| default_provider_name.to_string()); - - let lead_turns = config - .get_param::("GOOSE_LEAD_TURNS") - .unwrap_or(DEFAULT_LEAD_TURNS); - let failure_threshold = config - .get_param::("GOOSE_LEAD_FAILURE_THRESHOLD") - .unwrap_or(DEFAULT_FAILURE_THRESHOLD); - let fallback_turns = config - .get_param::("GOOSE_LEAD_FALLBACK_TURNS") - .unwrap_or(DEFAULT_FALLBACK_TURNS); - - let lead_model_config = ModelConfig::new_with_context_env( - lead_model_name.to_string(), - &lead_provider_name, - Some("GOOSE_LEAD_CONTEXT_LIMIT"), - )?; - - let worker_model_config = create_worker_model_config(default_model, default_provider_name)?; - - let registry = get_registry().await; - - let lead_constructor = { - let guard = registry.read().unwrap(); - guard - .entries - .get(&lead_provider_name) - .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", lead_provider_name))? - .constructor - .clone() - }; - - let worker_constructor = { - let guard = registry.read().unwrap(); - guard - .entries - .get(default_provider_name) - .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", default_provider_name))? - .constructor - .clone() - }; - - let lead_provider = lead_constructor(lead_model_config, extensions.clone()).await?; - let worker_provider = worker_constructor(worker_model_config, extensions).await?; - - Ok(Arc::new(LeadWorkerProvider::new_with_settings( - lead_provider, - worker_provider, - lead_turns, - failure_threshold, - fallback_turns, - ))) -} - -fn create_worker_model_config( - default_model: &ModelConfig, - provider_name: &str, -) -> Result { - let mut worker_config = ModelConfig::new_or_fail(&default_model.model_name) - .with_canonical_limits(provider_name) - .with_context_limit(default_model.context_limit) - .with_temperature(default_model.temperature) - .with_max_tokens(default_model.max_tokens) - .with_toolshim(default_model.toolshim) - .with_toolshim_model(default_model.toolshim_model.clone()); - - let global_config = crate::config::Config::global(); - - if let Ok(limit) = global_config.get_param::("GOOSE_WORKER_CONTEXT_LIMIT") { - worker_config = worker_config.with_context_limit(Some(limit)); - } else if let Ok(limit) = global_config.get_param::("GOOSE_CONTEXT_LIMIT") { - worker_config = worker_config.with_context_limit(Some(limit)); - } - - Ok(worker_config) -} - #[cfg(test)] mod tests { use super::*; - #[test_case::test_case(None, None, None, DEFAULT_LEAD_TURNS, DEFAULT_FAILURE_THRESHOLD, DEFAULT_FALLBACK_TURNS ; "defaults")] - #[test_case::test_case(Some("7"), Some("4"), Some("3"), 7, 4, 3 ; "custom")] - #[tokio::test] - async fn test_create_lead_worker_provider( - lead_turns: Option<&str>, - failure_threshold: Option<&str>, - fallback_turns: Option<&str>, - expected_turns: usize, - expected_failure: usize, - expected_fallback: usize, - ) { - let _guard = env_lock::lock_env([ - ("GOOSE_LEAD_MODEL", Some("gpt-4o")), - ("GOOSE_LEAD_PROVIDER", None), - ("GOOSE_LEAD_TURNS", lead_turns), - ("GOOSE_LEAD_FAILURE_THRESHOLD", failure_threshold), - ("GOOSE_LEAD_FALLBACK_TURNS", fallback_turns), - ("OPENAI_API_KEY", Some("fake-openai-no-keyring")), - ("OPENAI_CUSTOM_HEADERS", Some("")), - ]); - - let provider = create( - "openai", - ModelConfig::new_or_fail("gpt-4o-mini").with_canonical_limits("openai"), - Vec::new(), - ) - .await - .unwrap(); - let lw = provider.as_lead_worker().unwrap(); - let (lead, worker) = lw.get_model_info(); - assert_eq!(lead, "gpt-4o"); - assert_eq!(worker, "gpt-4o-mini"); - assert_eq!( - lw.get_settings(), - (expected_turns, expected_failure, expected_fallback) - ); - } - - #[tokio::test] - async fn test_create_regular_provider_without_lead_config() { - let _guard = env_lock::lock_env([ - ("GOOSE_LEAD_MODEL", None), - ("GOOSE_LEAD_PROVIDER", None), - ("GOOSE_LEAD_TURNS", None), - ("GOOSE_LEAD_FAILURE_THRESHOLD", None), - ("GOOSE_LEAD_FALLBACK_TURNS", None), - ("OPENAI_API_KEY", Some("fake-openai-no-keyring")), - ("OPENAI_CUSTOM_HEADERS", Some("")), - ]); - - let provider = create( - "openai", - ModelConfig::new_or_fail("gpt-4o-mini").with_canonical_limits("openai"), - Vec::new(), - ) - .await - .unwrap(); - assert!(provider.as_lead_worker().is_none()); - assert_eq!(provider.get_model_config().model_name, "gpt-4o-mini"); - } - - #[test_case::test_case(None, None, 16_000 ; "no overrides uses default")] - #[test_case::test_case(Some("32000"), None, 32_000 ; "worker limit overrides default")] - #[test_case::test_case(Some("32000"), Some("64000"), 32_000 ; "worker limit takes priority over global")] - fn test_worker_model_context_limit( - worker_limit: Option<&str>, - global_limit: Option<&str>, - expected_limit: usize, - ) { - let _guard = env_lock::lock_env([ - ("GOOSE_WORKER_CONTEXT_LIMIT", worker_limit), - ("GOOSE_CONTEXT_LIMIT", global_limit), - ]); - - let default_model = ModelConfig::new_or_fail("gpt-3.5-turbo") - .with_canonical_limits("openai") - .with_context_limit(Some(16_000)); - - let result = create_worker_model_config(&default_model, "openai").unwrap(); - assert_eq!(result.context_limit, Some(expected_limit)); - } - #[tokio::test] async fn test_tanzu_declarative_provider_registry_wiring() { let providers_list = providers().await; diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs deleted file mode 100644 index 7674620ed5ae..000000000000 --- a/crates/goose/src/providers/lead_worker.rs +++ /dev/null @@ -1,745 +0,0 @@ -use anyhow::{anyhow, Result}; -use async_trait::async_trait; -use std::ops::Deref; -use std::sync::Arc; -use tokio::sync::Mutex; - -use super::base::{ - collect_stream, stream_from_single_message, LeadWorkerProviderTrait, MessageStream, Provider, - ProviderDef, ProviderMetadata, ProviderUsage, -}; -use super::errors::ProviderError; -use crate::conversation::message::{Message, MessageContent}; -use crate::model::ModelConfig; -use futures::future::BoxFuture; -use rmcp::model::Tool; -use rmcp::model::{Content, RawContent}; - -const LEAD_WORKER_PROVIDER_NAME: &str = "lead_worker"; - -/// A provider that switches between a lead model and a worker model based on turn count -/// and can fall back to lead model on consecutive failures -pub struct LeadWorkerProvider { - lead_provider: Arc, - worker_provider: Arc, - lead_turns: usize, - turn_count: Arc>, - failure_count: Arc>, - max_failures_before_fallback: usize, - fallback_turns: usize, - in_fallback_mode: Arc>, - fallback_remaining: Arc>, -} - -impl LeadWorkerProvider { - /// Create a new LeadWorkerProvider - /// - /// # Arguments - /// * `lead_provider` - The provider to use for the initial turns - /// * `worker_provider` - The provider to use after lead_turns - /// * `lead_turns` - Number of turns to use the lead provider (default: 3) - pub fn new( - lead_provider: Arc, - worker_provider: Arc, - lead_turns: Option, - ) -> Self { - Self { - lead_provider, - worker_provider, - lead_turns: lead_turns.unwrap_or(3), - turn_count: Arc::new(Mutex::new(0)), - failure_count: Arc::new(Mutex::new(0)), - max_failures_before_fallback: 2, // Fallback after 2 consecutive failures - fallback_turns: 2, // Use lead model for 2 turns when in fallback mode - in_fallback_mode: Arc::new(Mutex::new(false)), - fallback_remaining: Arc::new(Mutex::new(0)), - } - } - - /// Create a new LeadWorkerProvider with custom settings - /// - /// # Arguments - /// * `lead_provider` - The provider to use for the initial turns - /// * `worker_provider` - The provider to use after lead_turns - /// * `lead_turns` - Number of turns to use the lead provider - /// * `failure_threshold` - Number of consecutive failures before fallback - /// * `fallback_turns` - Number of turns to use lead model in fallback mode - pub fn new_with_settings( - lead_provider: Arc, - worker_provider: Arc, - lead_turns: usize, - failure_threshold: usize, - fallback_turns: usize, - ) -> Self { - Self { - lead_provider, - worker_provider, - lead_turns, - turn_count: Arc::new(Mutex::new(0)), - failure_count: Arc::new(Mutex::new(0)), - max_failures_before_fallback: failure_threshold, - fallback_turns, - in_fallback_mode: Arc::new(Mutex::new(false)), - fallback_remaining: Arc::new(Mutex::new(0)), - } - } - - /// Reset the turn counter and failure tracking (useful for new conversations) - pub async fn reset_turn_count(&self) { - let mut count = self.turn_count.lock().await; - *count = 0; - let mut failures = self.failure_count.lock().await; - *failures = 0; - let mut fallback = self.in_fallback_mode.lock().await; - *fallback = false; - let mut remaining = self.fallback_remaining.lock().await; - *remaining = 0; - } - - /// Get the current turn count - pub async fn get_turn_count(&self) -> usize { - *self.turn_count.lock().await - } - - /// Get the current failure count - pub async fn get_failure_count(&self) -> usize { - *self.failure_count.lock().await - } - - /// Check if currently in fallback mode - pub async fn is_in_fallback_mode(&self) -> bool { - *self.in_fallback_mode.lock().await - } - - /// Get the currently active provider based on turn count and fallback state - async fn get_active_provider(&self) -> Arc { - let count = *self.turn_count.lock().await; - let in_fallback = *self.in_fallback_mode.lock().await; - - // Use lead provider if we're in initial turns OR in fallback mode - if count < self.lead_turns || in_fallback { - Arc::clone(&self.lead_provider) - } else { - Arc::clone(&self.worker_provider) - } - } - - /// Handle the result of a completion attempt and update failure tracking - async fn handle_completion_result( - &self, - result: &Result<(Message, ProviderUsage), ProviderError>, - ) { - match result { - Ok((message, _usage)) => { - // Check for task-level failures in the response - let has_task_failure = self.detect_task_failures(message).await; - - if has_task_failure { - // Task failure detected - increment failure count - let mut failures = self.failure_count.lock().await; - *failures += 1; - - let failure_count = *failures; - let turn_count = *self.turn_count.lock().await; - - tracing::warn!( - "Task failure detected in response (failure count: {})", - failure_count - ); - - // Check if we should trigger fallback - if turn_count >= self.lead_turns - && !*self.in_fallback_mode.lock().await - && failure_count >= self.max_failures_before_fallback - { - let mut in_fallback = self.in_fallback_mode.lock().await; - let mut fallback_remaining = self.fallback_remaining.lock().await; - - *in_fallback = true; - *fallback_remaining = self.fallback_turns; - *failures = 0; // Reset failure count when entering fallback - - tracing::warn!( - "🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns", - self.max_failures_before_fallback, - self.fallback_turns - ); - } - } else { - // Success - reset failure count and handle fallback mode - let mut failures = self.failure_count.lock().await; - *failures = 0; - - let mut in_fallback = self.in_fallback_mode.lock().await; - let mut fallback_remaining = self.fallback_remaining.lock().await; - - if *in_fallback { - *fallback_remaining -= 1; - if *fallback_remaining == 0 { - *in_fallback = false; - tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed"); - } - } - } - - // Increment turn count on any completion (success or task failure) - let mut count = self.turn_count.lock().await; - *count += 1; - } - Err(_) => { - // Technical failure - just log and let it bubble up - // For technical failures (API/LLM issues), we don't want to second-guess - // the model choice - just let the default model handle it - tracing::warn!( - "Technical failure detected - API/LLM issue, will use default model" - ); - - // Don't increment turn count or failure tracking for technical failures - // as these are temporary infrastructure issues, not model capability issues - } - } - } - - /// Detect task-level failures in the model's response - async fn detect_task_failures(&self, message: &Message) -> bool { - let mut failure_indicators = 0; - - for content in &message.content { - match content { - MessageContent::ToolRequest(tool_request) => { - // Check if tool request itself failed (malformed, etc.) - if tool_request.tool_call.is_err() { - failure_indicators += 1; - tracing::debug!( - "Failed tool request detected: {:?}", - tool_request.tool_call - ); - } - } - MessageContent::ToolResponse(tool_response) => { - // Check if tool execution failed - if let Err(tool_error) = &tool_response.tool_result { - failure_indicators += 1; - tracing::debug!("Tool execution failure detected: {:?}", tool_error); - } else if let Ok(result) = &tool_response.tool_result { - // Check tool output for error indicators - if self.contains_error_indicators(&result.content) { - failure_indicators += 1; - tracing::debug!("Tool output contains error indicators"); - } - } - } - MessageContent::Text(text_content) => { - // Check for user correction patterns or error acknowledgments - if self.contains_user_correction_patterns(&text_content.text) { - failure_indicators += 1; - tracing::debug!("User correction pattern detected in text"); - } - } - _ => {} - } - } - - // Consider it a failure if we have multiple failure indicators - failure_indicators >= 1 - } - - /// Check if tool output contains error indicators - fn contains_error_indicators(&self, contents: &[Content]) -> bool { - for content in contents { - if let RawContent::Text(text_content) = content.deref() { - let text_lower = text_content.text.to_lowercase(); - - // Common error patterns in tool outputs - if text_lower.contains("error:") - || text_lower.contains("failed:") - || text_lower.contains("exception:") - || text_lower.contains("traceback") - || text_lower.contains("syntax error") - || text_lower.contains("permission denied") - || text_lower.contains("file not found") - || text_lower.contains("command not found") - || text_lower.contains("compilation failed") - || text_lower.contains("test failed") - || text_lower.contains("assertion failed") - { - return true; - } - } - } - false - } - - /// Check for user correction patterns in text - fn contains_user_correction_patterns(&self, text: &str) -> bool { - let text_lower = text.to_lowercase(); - - // Patterns indicating user is correcting or expressing dissatisfaction - text_lower.contains("that's wrong") - || text_lower.contains("that's not right") - || text_lower.contains("that doesn't work") - || text_lower.contains("try again") - || text_lower.contains("let me correct") - || text_lower.contains("actually, ") - || text_lower.contains("no, that's") - || text_lower.contains("that's incorrect") - || text_lower.contains("fix this") - || text_lower.contains("this is broken") - || text_lower.contains("this doesn't") - || text_lower.starts_with("no,") - || text_lower.starts_with("wrong") - || text_lower.starts_with("incorrect") - } -} - -impl LeadWorkerProviderTrait for LeadWorkerProvider { - /// Get information about the lead and worker models for logging - fn get_model_info(&self) -> (String, String) { - let lead_model = self.lead_provider.get_model_config().model_name; - let worker_model = self.worker_provider.get_model_config().model_name; - (lead_model, worker_model) - } - - /// Get the currently active model name - fn get_active_model(&self) -> String { - // Read from the global store which was set during complete() - use super::base::get_current_model; - get_current_model().unwrap_or_else(|| { - // Fallback to lead model if no current model is set - self.lead_provider.get_model_config().model_name - }) - } - - /// Get (lead_turns, failure_threshold, fallback_turns) - fn get_settings(&self) -> (usize, usize, usize) { - ( - self.lead_turns, - self.max_failures_before_fallback, - self.fallback_turns, - ) - } -} - -impl ProviderDef for LeadWorkerProvider { - type Provider = Self; - - fn metadata() -> ProviderMetadata { - // This is a wrapper provider, so we return minimal metadata - ProviderMetadata::new( - LEAD_WORKER_PROVIDER_NAME, - "Lead/Worker Provider", - "A provider that switches between lead and worker models based on turn count", - "", // No default model as this is determined by the wrapped providers - vec![], // No known models as this depends on wrapped providers - "", // No doc link - vec![], // No config keys as configuration is done through wrapped providers - ) - } - - fn from_env( - _model: ModelConfig, - _extensions: Vec, - ) -> BoxFuture<'static, Result> { - Box::pin(async { Err(anyhow!("LeadWorkerProvider must be constructed explicitly")) }) - } -} - -#[async_trait] -impl Provider for LeadWorkerProvider { - fn get_name(&self) -> &str { - // Return the lead provider's name as the default - self.lead_provider.get_name() - } - - fn get_model_config(&self) -> ModelConfig { - // Return the lead provider's model config as the default - // In practice, this might need to be more sophisticated - self.lead_provider.get_model_config() - } - - async fn stream( - &self, - _model_config: &ModelConfig, - session_id: &str, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result { - // Get the active provider - let provider = self.get_active_provider().await; - - // Log which provider is being used - let turn_count = *self.turn_count.lock().await; - let in_fallback = *self.in_fallback_mode.lock().await; - let fallback_remaining = *self.fallback_remaining.lock().await; - - let provider_type = if turn_count < self.lead_turns { - "lead (initial)" - } else if in_fallback { - "lead (fallback)" - } else { - "worker" - }; - - // Get the active model name and update the global store - let active_model_name = if turn_count < self.lead_turns || in_fallback { - self.lead_provider.get_model_config().model_name.clone() - } else { - self.worker_provider.get_model_config().model_name.clone() - }; - - // Update the global current model store - super::base::set_current_model(&active_model_name); - - if in_fallback { - tracing::info!( - "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining) - Model: {}", - provider_type, - turn_count + 1, - fallback_remaining, - active_model_name - ); - } else { - tracing::info!( - "Using {} provider for turn {} (lead_turns: {}) - Model: {}", - provider_type, - turn_count + 1, - self.lead_turns, - active_model_name - ); - } - - // Make the completion request - let model_config = provider.get_model_config(); - let stream_result = provider - .stream(&model_config, session_id, system, messages, tools) - .await; - let result = match stream_result { - Ok(stream) => collect_stream(stream).await, - Err(e) => Err(e), - }; - - // For technical failures, try with default model (lead provider) instead - let final_result = match &result { - Err(_) => { - tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type); - - // Try with lead provider as the default/fallback for technical failures - let model_config = self.lead_provider.get_model_config(); - let default_stream_result = self - .lead_provider - .stream(&model_config, session_id, system, messages, tools) - .await; - let default_result = match default_stream_result { - Ok(stream) => collect_stream(stream).await, - Err(e) => Err(e), - }; - - match &default_result { - Ok(_) => { - tracing::info!( - "✅ Default model (lead provider) succeeded after technical failure" - ); - default_result - } - Err(_) => { - tracing::error!("❌ Default model (lead provider) also failed - returning original error"); - result // Return the original error - } - } - } - Ok(_) => result, // Success with original provider - }; - - // Handle the result and update tracking (only for successful completions) - self.handle_completion_result(&final_result).await; - - match final_result { - Ok((message, usage)) => Ok(stream_from_single_message(message, usage)), - Err(e) => Err(e), - } - } - - async fn fetch_supported_models(&self) -> Result, ProviderError> { - // Combine models from both providers - let mut all_models = self.lead_provider.fetch_supported_models().await?; - let worker_models = self.worker_provider.fetch_supported_models().await?; - all_models.extend(worker_models); - all_models.sort(); - all_models.dedup(); - Ok(all_models) - } - - fn supports_embeddings(&self) -> bool { - // Support embeddings if either provider supports them - self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings() - } - - async fn create_embeddings( - &self, - session_id: &str, - texts: Vec, - ) -> Result>, ProviderError> { - // Use the lead provider for embeddings if it supports them, otherwise use worker - if self.lead_provider.supports_embeddings() { - self.lead_provider - .create_embeddings(session_id, texts) - .await - } else if self.worker_provider.supports_embeddings() { - self.worker_provider - .create_embeddings(session_id, texts) - .await - } else { - Err(ProviderError::ExecutionError( - "Neither lead nor worker provider supports embeddings".to_string(), - )) - } - } - - /// Check if this provider is a LeadWorkerProvider - fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> { - Some(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::conversation::message::{Message, MessageContent}; - use crate::providers::base::{ProviderUsage, Usage}; - use chrono::Utc; - use rmcp::model::{AnnotateAble, RawTextContent, Role}; - - #[derive(Clone)] - struct MockProvider { - name: String, - model_config: ModelConfig, - } - - #[async_trait] - impl Provider for MockProvider { - fn get_name(&self) -> &str { - "mock-lead" - } - - fn get_model_config(&self) -> ModelConfig { - self.model_config.clone() - } - - async fn stream( - &self, - _model_config: &ModelConfig, - _session_id: &str, - _system: &str, - _messages: &[Message], - _tools: &[Tool], - ) -> Result { - let message = Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::Text( - RawTextContent { - text: format!("Response from {}", self.name), - meta: None, - } - .no_annotation(), - )], - ); - let usage = ProviderUsage::new(self.name.clone(), Usage::default()); - Ok(stream_from_single_message(message, usage)) - } - } - - #[tokio::test] - async fn test_lead_worker_switching() { - let lead_provider = Arc::new(MockProvider { - name: "lead".to_string(), - model_config: ModelConfig::new_or_fail("lead-model"), - }); - - let worker_provider = Arc::new(MockProvider { - name: "worker".to_string(), - model_config: ModelConfig::new_or_fail("worker-model"), - }); - - let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3)); - let model_config = provider.get_model_config(); - - // First three turns should use lead provider - for i in 0..3 { - let (_message, usage) = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await - .unwrap(); - assert_eq!(usage.model, "lead"); - assert_eq!(provider.get_turn_count().await, i + 1); - assert!(!provider.is_in_fallback_mode().await); - } - - // Subsequent turns should use worker provider - for i in 3..6 { - let (_message, usage) = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await - .unwrap(); - assert_eq!(usage.model, "worker"); - assert_eq!(provider.get_turn_count().await, i + 1); - assert!(!provider.is_in_fallback_mode().await); - } - - // Reset and verify it goes back to lead - provider.reset_turn_count().await; - assert_eq!(provider.get_turn_count().await, 0); - assert_eq!(provider.get_failure_count().await, 0); - assert!(!provider.is_in_fallback_mode().await); - - let (_message, usage) = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await - .unwrap(); - assert_eq!(usage.model, "lead"); - } - - #[tokio::test] - async fn test_technical_failure_retry() { - let lead_provider = Arc::new(MockFailureProvider { - name: "lead".to_string(), - model_config: ModelConfig::new_or_fail("lead-model"), - should_fail: false, // Lead provider works - }); - - let worker_provider = Arc::new(MockFailureProvider { - name: "worker".to_string(), - model_config: ModelConfig::new_or_fail("worker-model"), - should_fail: true, // Worker will fail - }); - - let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2)); - let model_config = provider.get_model_config(); - - // First two turns use lead (should succeed) - for _i in 0..2 { - let result = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await; - assert!(result.is_ok()); - assert_eq!(result.unwrap().1.model, "lead"); - assert!(!provider.is_in_fallback_mode().await); - } - - // Next turn uses worker (will fail, but should retry with lead and succeed) - let model_config = provider.get_model_config(); - let result = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await; - assert!(result.is_ok()); // Should succeed because lead provider is used as fallback - assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider - assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures - assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode - - // Another turn - should still try worker first, then retry with lead - let model_config = provider.get_model_config(); - let result = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await; - assert!(result.is_ok()); // Should succeed because lead provider is used as fallback - assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider - assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking - assert!(!provider.is_in_fallback_mode().await); // Still not in fallback mode - } - - #[tokio::test] - async fn test_fallback_on_task_failures() { - // Test that task failures (not technical failures) still trigger fallback mode - // This would need a different mock that simulates task failures in successful responses - // For now, we'll test the fallback mode functionality directly - let lead_provider = Arc::new(MockFailureProvider { - name: "lead".to_string(), - model_config: ModelConfig::new_or_fail("lead-model"), - should_fail: false, - }); - - let worker_provider = Arc::new(MockFailureProvider { - name: "worker".to_string(), - model_config: ModelConfig::new_or_fail("worker-model"), - should_fail: false, - }); - - let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2)); - - // Simulate being in fallback mode - { - let mut in_fallback = provider.in_fallback_mode.lock().await; - *in_fallback = true; - let mut fallback_remaining = provider.fallback_remaining.lock().await; - *fallback_remaining = 2; - let mut turn_count = provider.turn_count.lock().await; - *turn_count = 4; // Past initial lead turns - } - - // Should use lead provider in fallback mode - let model_config = provider.get_model_config(); - let result = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await; - assert!(result.is_ok()); - assert_eq!(result.unwrap().1.model, "lead"); - assert!(provider.is_in_fallback_mode().await); - - // One more fallback turn - let model_config = provider.get_model_config(); - let result = provider - .complete(&model_config, "test-session-id", "system", &[], &[]) - .await; - assert!(result.is_ok()); - assert_eq!(result.unwrap().1.model, "lead"); - assert!(!provider.is_in_fallback_mode().await); // Should exit fallback mode - } - - #[derive(Clone)] - struct MockFailureProvider { - name: String, - model_config: ModelConfig, - should_fail: bool, - } - - #[async_trait] - impl Provider for MockFailureProvider { - fn get_name(&self) -> &str { - "mock-lead" - } - - fn get_model_config(&self) -> ModelConfig { - self.model_config.clone() - } - - async fn stream( - &self, - _model_config: &ModelConfig, - _session_id: &str, - _system: &str, - _messages: &[Message], - _tools: &[Tool], - ) -> Result { - if self.should_fail { - Err(ProviderError::ExecutionError( - "Simulated failure".to_string(), - )) - } else { - let message = Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::Text( - RawTextContent { - text: format!("Response from {}", self.name), - meta: None, - } - .no_annotation(), - )], - ); - let usage = ProviderUsage::new(self.name.clone(), Usage::default()); - Ok(stream_from_single_message(message, usage)) - } - } - } -} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 333f791122a1..b1d51b3de0cc 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -25,7 +25,6 @@ pub mod gemini_cli; pub mod githubcopilot; pub mod google; mod init; -pub mod lead_worker; pub mod litellm; pub mod local_inference; pub mod nanogpt; diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 23adc483fc32..a267e7237307 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -465,7 +465,6 @@ mod tests { responses.push(response); } Ok(AgentEvent::McpNotification(_)) => {} - Ok(AgentEvent::ModelChange { .. }) => {} Ok(AgentEvent::HistoryReplaced(_updated_conversation)) => { // We should update the conversation here, but we're not reading it } diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 2a81f2ee4189..1bc4f207affd 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -6113,28 +6113,6 @@ } } }, - { - "type": "object", - "required": [ - "model", - "mode", - "type" - ], - "properties": { - "mode": { - "type": "string" - }, - "model": { - "type": "string" - }, - "type": { - "type": "string", - "enum": [ - "ModelChange" - ] - } - } - }, { "type": "object", "required": [ diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 10b7bfeaccd8..6add7592e84e 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -687,10 +687,6 @@ export type MessageEvent = { reason: string; token_state: TokenState; type: 'Finish'; -} | { - mode: string; - model: string; - type: 'ModelChange'; } | { message: { [key: string]: unknown; diff --git a/ui/desktop/src/components/BaseChat.tsx b/ui/desktop/src/components/BaseChat.tsx index b88c9cc71713..6ab1a9c4b065 100644 --- a/ui/desktop/src/components/BaseChat.tsx +++ b/ui/desktop/src/components/BaseChat.tsx @@ -1,8 +1,6 @@ import { AppEvents } from '../constants/events'; import React, { - createContext, useCallback, - useContext, useEffect, useMemo, useRef, @@ -42,8 +40,7 @@ import { useAutoSubmit } from '../hooks/useAutoSubmit'; import { Goose } from './icons'; import EnvironmentBadge from './GooseSidebar/EnvironmentBadge'; -const CurrentModelContext = createContext<{ model: string; mode: string } | null>(null); -export const useCurrentModelInfo = () => useContext(CurrentModelContext); + interface BaseChatProps { setChat: (chat: ChatType) => void; diff --git a/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx b/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx index e3e1a51b4c73..8e4d956dc7e5 100644 --- a/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx +++ b/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx @@ -2,7 +2,6 @@ import { Sliders, Bot, Settings } from 'lucide-react'; import React, { useEffect, useState } from 'react'; import { useModelAndProvider } from '../../../ModelAndProviderContext'; import { SwitchModelModal } from '../subcomponents/SwitchModelModal'; -import { LeadWorkerSettings } from '../subcomponents/LeadWorkerSettings'; import { View } from '../../../../utils/navigationUtils'; import { DropdownMenu, @@ -10,7 +9,6 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from '../../../ui/dropdown-menu'; -import { useCurrentModelInfo } from '../../../BaseChat'; import { useConfig } from '../../../ConfigContext'; import { getProviderMetadata } from '../modelInterface'; import { getModelDisplayName } from '../predefinedModelsUtils'; @@ -46,81 +44,17 @@ export default function ModelsBottomBar({ const currentModel = sessionModel ?? configModel; const currentProvider = sessionProvider ?? configProvider; - const currentModelInfo = useCurrentModelInfo(); - const { read, getProviders } = useConfig(); + const { getProviders } = useConfig(); const [displayProvider, setDisplayProvider] = useState(null); const [displayModelName, setDisplayModelName] = useState('Select Model'); const [isAddModelModalOpen, setIsAddModelModalOpen] = useState(false); - const [isLeadWorkerModalOpen, setIsLeadWorkerModalOpen] = useState(false); const [isLocalModelSettingsOpen, setIsLocalModelSettingsOpen] = useState(false); - const [isLeadWorkerActive, setIsLeadWorkerActive] = useState(false); const [providerDefaultModel, setProviderDefaultModel] = useState(null); - // Check if lead/worker mode is active - useEffect(() => { - const checkLeadWorker = async () => { - try { - const leadModel = await read('GOOSE_LEAD_MODEL', false); - setIsLeadWorkerActive(!!leadModel); - } catch (error) { - console.error('Error checking lead model:', error); - setIsLeadWorkerActive(false); - } - }; - checkLeadWorker(); - }, [read]); - - // Refresh lead/worker status when modal closes - const handleLeadWorkerModalClose = () => { - setIsLeadWorkerModalOpen(false); - const checkLeadWorker = async () => { - try { - const leadModel = await read('GOOSE_LEAD_MODEL', false); - const currentModel = await read('GOOSE_MODEL', false); - setIsLeadWorkerActive(!!leadModel); - setLeadModelName((leadModel as string) || ''); - setCurrentActiveModel((currentModel as string) || ''); - } catch (error) { - console.error('Error checking lead model after modal close:', error); - setIsLeadWorkerActive(false); - } - }; - checkLeadWorker(); - }; - - const [leadModelName, setLeadModelName] = useState(''); - const [currentActiveModel, setCurrentActiveModel] = useState(''); - - // Get lead model name and current model for comparison - useEffect(() => { - const getModelInfo = async () => { - try { - const leadModel = await read('GOOSE_LEAD_MODEL', false); - const currentModel = await read('GOOSE_MODEL', false); - setLeadModelName((leadModel as string) || ''); - setCurrentActiveModel((currentModel as string) || ''); - } catch (error) { - console.error('Error getting model info:', error); - } - }; - getModelInfo(); - }, [read]); - - // Determine the mode based on which model is currently active - const modelMode = isLeadWorkerActive - ? currentActiveModel === leadModelName - ? 'lead' - : 'worker' - : undefined; - - // Determine which model to display - activeModel takes priority when lead/worker is active // Hide label while session data is still being fetched (avoids flashing // the config default before the session's actual model arrives). const isModelLoading = sessionId && !sessionLoaded; - const displayModel = - isLeadWorkerActive && currentModelInfo?.model - ? currentModelInfo.model - : currentModel || providerDefaultModel || displayModelName; + const displayModel = currentModel || providerDefaultModel || displayModelName; useEffect(() => { if (!currentProvider) return; @@ -168,9 +102,6 @@ export default function ModelsBottomBar({ {displayModel} - {isLeadWorkerActive && modelMode && ( - ({modelMode}) - )} @@ -184,10 +115,6 @@ export default function ModelsBottomBar({ Change Model - setIsLeadWorkerModalOpen(true)}> - Lead/Worker Settings - - {currentProvider === 'local' && currentModel && ( setIsLocalModelSettingsOpen(true)}> Local Model Settings @@ -208,10 +135,6 @@ export default function ModelsBottomBar({ /> ) : null} - {isLeadWorkerModalOpen ? ( - - ) : null} - {isLocalModelSettingsOpen && currentModel && (
diff --git a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.test.tsx b/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.test.tsx deleted file mode 100644 index 5dcebda3cace..000000000000 --- a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.test.tsx +++ /dev/null @@ -1,147 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { render, screen, waitFor, fireEvent } from '@testing-library/react'; -import { LeadWorkerSettings } from './LeadWorkerSettings'; - -// Mock predefined models utils to force provider-based options (no predefined list) -vi.mock('../predefinedModelsUtils', () => ({ - shouldShowPredefinedModels: () => false, - getPredefinedModelsFromEnv: () => [], -})); - -// Mocks for useConfig -const mockRead = vi.fn(); -const mockUpsert = vi.fn(); -const mockRemove = vi.fn(); -const mockGetProviders = vi.fn(); - -vi.mock('../../../ConfigContext', () => ({ - useConfig: () => ({ - read: mockRead, - upsert: mockUpsert, - remove: mockRemove, - getProviders: mockGetProviders, - }), -})); - -describe('LeadWorkerSettings', () => { - beforeEach(() => { - vi.clearAllMocks(); - }); - - const setupHappyPathMocks = () => { - // reads - mockRead.mockImplementation(async (key: string) => { - switch (key) { - case 'GOOSE_LEAD_MODEL': - return 'my-custom-lead'; - case 'GOOSE_LEAD_PROVIDER': - return 'anthropic'; - case 'GOOSE_LEAD_TURNS': - return 3; - case 'GOOSE_LEAD_FAILURE_THRESHOLD': - return 2; - case 'GOOSE_LEAD_FALLBACK_TURNS': - return 2; - case 'GOOSE_MODEL': - return 'my-custom-worker'; - case 'GOOSE_PROVIDER': - return 'openai'; - default: - return null; - } - }); - - // providers (options do NOT include the custom models above) - mockGetProviders.mockResolvedValue([ - { - is_configured: true, - name: 'openai', - metadata: { - display_name: 'OpenAI', - known_models: [{ name: 'gpt-4o' }, { name: 'gpt-4o-mini' }], - }, - }, - { - is_configured: true, - name: 'anthropic', - metadata: { - display_name: 'Anthropic', - known_models: [{ name: 'claude-3-5-sonnet' }], - }, - }, - ]); - - // writers - mockUpsert.mockResolvedValue(undefined); - mockRemove.mockResolvedValue(undefined); - }; - - it('shows custom inputs for lead/worker when current models are unknown and saves them', async () => { - setupHappyPathMocks(); - - const onClose = vi.fn(); - render(); - - // Wait for modal content (not loading) - await waitFor(() => { - expect(screen.getByText('Lead/Worker Mode')).toBeInTheDocument(); - }); - - // Labels should be present with back-to-list controls - await waitFor(() => { - expect(screen.getByText('Lead Model')).toBeInTheDocument(); - expect(screen.getByText('Worker Model')).toBeInTheDocument(); - // Back to model list appears for each section when in custom mode - const backLinks = screen.getAllByText('Back to model list'); - expect(backLinks.length).toBeGreaterThanOrEqual(2); - }); - - const inputs = screen.getAllByPlaceholderText('Type model name here') as HTMLInputElement[]; - expect(inputs.length).toBe(2); - const [leadInput, workerInput] = inputs; - expect(leadInput.value).toBe('my-custom-lead'); - expect(workerInput.value).toBe('my-custom-worker'); - - // Save settings - const saveBtn = screen.getByRole('button', { name: 'Save Settings' }); - expect(saveBtn).toBeEnabled(); - fireEvent.click(saveBtn); - - // Assert upserts for models (providers are optional but present in this setup) - await waitFor(() => { - expect(mockUpsert).toHaveBeenCalledWith('GOOSE_LEAD_MODEL', 'my-custom-lead', false); - expect(mockUpsert).toHaveBeenCalledWith('GOOSE_MODEL', 'my-custom-worker', false); - expect(mockUpsert).toHaveBeenCalledWith('GOOSE_LEAD_PROVIDER', 'anthropic', false); - expect(mockUpsert).toHaveBeenCalledWith('GOOSE_PROVIDER', 'openai', false); - }); - }); - - it('disables lead/worker and removes config when toggled off', async () => { - setupHappyPathMocks(); - - const onClose = vi.fn(); - render(); - - await waitFor(() => { - expect(screen.getByText('Lead/Worker Mode')).toBeInTheDocument(); - }); - - // Toggle off - const checkbox = screen.getByLabelText('Enable lead/worker mode') as HTMLInputElement; - expect(checkbox.checked).toBe(true); - fireEvent.click(checkbox); - expect(checkbox.checked).toBe(false); - - const saveBtn = screen.getByRole('button', { name: 'Save Settings' }); - expect(saveBtn).toBeEnabled(); - fireEvent.click(saveBtn); - - await waitFor(() => { - expect(mockRemove).toHaveBeenCalledWith('GOOSE_LEAD_MODEL', false); - expect(mockRemove).toHaveBeenCalledWith('GOOSE_LEAD_PROVIDER', false); - expect(mockRemove).toHaveBeenCalledWith('GOOSE_LEAD_TURNS', false); - expect(mockRemove).toHaveBeenCalledWith('GOOSE_LEAD_FAILURE_THRESHOLD', false); - expect(mockRemove).toHaveBeenCalledWith('GOOSE_LEAD_FALLBACK_TURNS', false); - }); - }); -}); diff --git a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx b/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx deleted file mode 100644 index 8acd5648ab76..000000000000 --- a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx +++ /dev/null @@ -1,419 +0,0 @@ -import { useState, useEffect } from 'react'; -import { useConfig } from '../../../ConfigContext'; -import { Button } from '../../../ui/button'; -import { Select } from '../../../ui/Select'; -import { Input } from '../../../ui/input'; -import { getPredefinedModelsFromEnv, shouldShowPredefinedModels } from '../predefinedModelsUtils'; -import { Dialog, DialogContent, DialogHeader, DialogTitle } from '../../../ui/dialog'; -import { fetchModelsForProviders } from '../modelInterface'; - -interface LeadWorkerSettingsProps { - isOpen: boolean; - onClose: () => void; -} - -export function LeadWorkerSettings({ isOpen, onClose }: LeadWorkerSettingsProps) { - const { read, upsert, getProviders, remove } = useConfig(); - const [leadModel, setLeadModel] = useState(''); - const [workerModel, setWorkerModel] = useState(''); - const [leadProvider, setLeadProvider] = useState(''); - const [workerProvider, setWorkerProvider] = useState(''); - // Minimal custom model mode toggles - const [isLeadCustomModel, setIsLeadCustomModel] = useState(false); - const [isWorkerCustomModel, setIsWorkerCustomModel] = useState(false); - const [leadTurns, setLeadTurns] = useState(3); - const [failureThreshold, setFailureThreshold] = useState(2); - const [fallbackTurns, setFallbackTurns] = useState(2); - const [isEnabled, setIsEnabled] = useState(false); - const [modelOptions, setModelOptions] = useState< - { value: string; label: string; provider: string }[] - >([]); - const [isLoading, setIsLoading] = useState(true); - - // Load current configuration - useEffect(() => { - if (!isOpen) return; // Only load when modal is open - - const loadConfig = async () => { - try { - setIsLoading(true); - const [ - leadModelConfig, - leadProviderConfig, - leadTurnsConfig, - failureThresholdConfig, - fallbackTurnsConfig, - ] = await Promise.all([ - read('GOOSE_LEAD_MODEL', false), - read('GOOSE_LEAD_PROVIDER', false), - read('GOOSE_LEAD_TURNS', false), - read('GOOSE_LEAD_FAILURE_THRESHOLD', false), - read('GOOSE_LEAD_FALLBACK_TURNS', false), - ]); - - if (leadModelConfig) { - setLeadModel(leadModelConfig as string); - setIsEnabled(true); - } else { - setLeadModel(''); - setIsEnabled(false); - } - if (leadProviderConfig) setLeadProvider(leadProviderConfig as string); - else setLeadProvider(''); - if (leadTurnsConfig) setLeadTurns(Number(leadTurnsConfig)); - else setLeadTurns(3); - if (failureThresholdConfig) setFailureThreshold(Number(failureThresholdConfig)); - else setFailureThreshold(2); - if (fallbackTurnsConfig) setFallbackTurns(Number(fallbackTurnsConfig)); - else setFallbackTurns(2); - - // Set worker model from config - const workerModelConfig = await read('GOOSE_MODEL', false); - if (workerModelConfig) { - setWorkerModel(workerModelConfig as string); - } else { - setWorkerModel(''); - } - - const workerProviderConfig = await read('GOOSE_PROVIDER', false); - if (workerProviderConfig) { - setWorkerProvider(workerProviderConfig as string); - } else { - setWorkerProvider(''); - } - - // Load available models - const options: { value: string; label: string; provider: string }[] = []; - - if (shouldShowPredefinedModels()) { - // Use predefined models if available - const predefinedModels = getPredefinedModelsFromEnv(); - predefinedModels.forEach((model) => { - options.push({ - value: model.name, // Use name for switching - label: model.alias || model.name, // Use alias for display, fall back to name - provider: model.provider, - }); - }); - } else { - // Fallback to provider-based models - const providers = await getProviders(false); - const activeProviders = providers.filter((p) => p.is_configured); - - const results = await fetchModelsForProviders(activeProviders); - - results.forEach(({ provider: p, models, error }) => { - if (error) { - console.error(error); - } - - if (models && models.length > 0) { - models.forEach((modelName) => { - options.push({ - value: modelName, - label: `${modelName} (${p.metadata.display_name})`, - provider: p.name, - }); - }); - } - // Add custom model option for all non-Custom providers - if (p.provider_type !== 'Custom') { - options.push({ - value: `__custom__:${p.name}`, - label: 'Enter a model not listed...', - provider: p.name, - }); - } - }); - } - - setModelOptions(options); - } catch (error) { - console.error('Error loading configuration:', error); - } finally { - setIsLoading(false); - } - }; - - loadConfig(); - }, [read, getProviders, isOpen]); - - // If current models are not in the list (e.g., previously set to custom), switch to custom mode - useEffect(() => { - if (!isLoading) { - if (leadModel && !modelOptions.find((opt) => opt.value === leadModel)) { - setIsLeadCustomModel(true); - } - if (workerModel && !modelOptions.find((opt) => opt.value === workerModel)) { - setIsWorkerCustomModel(true); - } - } - }, [isLoading, modelOptions, leadModel, workerModel]); - - const handleSave = async () => { - try { - if (isEnabled && leadModel && workerModel) { - // Save lead/worker configuration - await Promise.all([ - upsert('GOOSE_LEAD_MODEL', leadModel, false), - leadProvider && upsert('GOOSE_LEAD_PROVIDER', leadProvider, false), - upsert('GOOSE_MODEL', workerModel, false), - workerProvider && upsert('GOOSE_PROVIDER', workerProvider, false), - upsert('GOOSE_LEAD_TURNS', leadTurns, false), - upsert('GOOSE_LEAD_FAILURE_THRESHOLD', failureThreshold, false), - upsert('GOOSE_LEAD_FALLBACK_TURNS', fallbackTurns, false), - ]); - } else { - // Remove lead/worker configuration - await Promise.all([ - remove('GOOSE_LEAD_MODEL', false), - remove('GOOSE_LEAD_PROVIDER', false), - remove('GOOSE_LEAD_TURNS', false), - remove('GOOSE_LEAD_FAILURE_THRESHOLD', false), - remove('GOOSE_LEAD_FALLBACK_TURNS', false), - ]); - } - onClose(); - } catch (error) { - console.error('Error saving configuration:', error); - } - }; - - if (isLoading) { - return ( - !open && onClose()}> - - - Lead/Worker Mode - -
Loading...
-
-
- ); - } - - return ( - !open && onClose()}> - - - Lead/Worker Mode - -
-
-

- Configure a lead model for planning and a worker model for execution -

-
- -
- setIsEnabled(e.target.checked)} - className="rounded border-border-primary" - /> - -
- -
-
-
- - {isLeadCustomModel && ( - - )} -
- {!isLeadCustomModel ? ( - setLeadModel(event.target.value)} - value={leadModel} - disabled={!isEnabled} - /> - )} -

- Strong model for initial planning and fallback recovery -

-
- -
-
- - {isWorkerCustomModel && ( - - )} -
- {!isWorkerCustomModel ? ( - setWorkerModel(event.target.value)} - value={workerModel} - disabled={!isEnabled} - /> - )} -

- Fast model for routine execution tasks -

-
- -
-
- - setLeadTurns(Number(e.target.value))} - className={`w-20 ${!isEnabled ? 'opacity-50 cursor-not-allowed' : ''}`} - disabled={!isEnabled} - /> -

- Number of turns to use the lead model at the start -

-
- -
- - setFailureThreshold(Number(e.target.value))} - className={`w-20 ${!isEnabled ? 'opacity-50 cursor-not-allowed' : ''}`} - disabled={!isEnabled} - /> -

- Consecutive failures before switching back to lead -

-
- -
- - setFallbackTurns(Number(e.target.value))} - className={`w-20 ${!isEnabled ? 'opacity-50 cursor-not-allowed' : ''}`} - disabled={!isEnabled} - /> -

- Turns to use lead model during fallback -

-
-
-
- -
- - -
-
-
-
- ); -} diff --git a/ui/desktop/src/hooks/useChatStream.ts b/ui/desktop/src/hooks/useChatStream.ts index 6a94b0933888..0306b3ea9041 100644 --- a/ui/desktop/src/hooks/useChatStream.ts +++ b/ui/desktop/src/hooks/useChatStream.ts @@ -318,9 +318,6 @@ function createEventProcessor( onFinish(); return true; } - case 'ModelChange': { - return false; - } case 'UpdateConversation': { const conversation = (event as Record).conversation as Message[]; currentMessages = conversation; diff --git a/ui/desktop/src/utils/configUtils.ts b/ui/desktop/src/utils/configUtils.ts index e8c3625b2d31..0696f7ae277b 100644 --- a/ui/desktop/src/utils/configUtils.ts +++ b/ui/desktop/src/utils/configUtils.ts @@ -4,8 +4,6 @@ export const configLabels: Record = { GOOSE_MODEL: 'Model', GOOSE_TEMPERATURE: 'Temperature', GOOSE_MODE: 'Mode', - GOOSE_LEAD_PROVIDER: 'Lead Provider', - GOOSE_LEAD_MODEL: 'Lead Model', GOOSE_PLANNER_PROVIDER: 'Planner Provider', GOOSE_PLANNER_MODEL: 'Planner Model', GOOSE_TOOLSHIM: 'Tool Shim',