diff --git a/crates/goose-cli/src/commands/acp.rs b/crates/goose-cli/src/commands/acp.rs index 7717c5d5c92b..8ee9acfe09d4 100644 --- a/crates/goose-cli/src/commands/acp.rs +++ b/crates/goose-cli/src/commands/acp.rs @@ -117,7 +117,7 @@ impl GooseAcpAgent { toolshim_model: None, fast_model: None, }; - let provider = create(&provider_name, model_config)?; + let provider = create(&provider_name, model_config).await?; // Create a shared agent instance let agent = Agent::new(); diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index c1d3b53041b3..f1824d935708 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -258,7 +258,7 @@ async fn handle_oauth_configuration( // Create a temporary provider instance to handle OAuth let temp_model = ModelConfig::new("temp")?; - match create(provider_name, temp_model) { + match create(provider_name, temp_model).await { Ok(provider) => match provider.configure_oauth().await { Ok(_) => { let _ = cliclack::log::success("OAuth authentication completed successfully!"); @@ -420,7 +420,7 @@ pub async fn configure_provider_dialog() -> Result> { let config = Config::global(); // Get all available providers and their metadata - let available_providers = providers(); + let available_providers = providers().await; // Create selection items from provider metadata let provider_items: Vec<(&String, &str, &str)> = available_providers @@ -550,7 +550,7 @@ pub async fn configure_provider_dialog() -> Result> { spin.start("Attempting to fetch supported models..."); let models_res = { let temp_model_config = ModelConfig::new(&provider_meta.default_model)?; - let temp_provider = create(provider_name, temp_model_config)?; + let temp_provider = create(provider_name, temp_model_config).await?; temp_provider.fetch_supported_models().await }; spin.stop(style("Model fetch complete").green()); @@ -586,7 +586,7 @@ pub async fn configure_provider_dialog() -> Result> { .with_toolshim(toolshim_enabled) .with_toolshim_model(std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL").ok()); - let provider = create(provider_name, model_config)?; + let provider = create(provider_name, model_config).await?; let messages = vec![Message::user().with_text("What is the weather like in San Francisco today?")]; @@ -1419,7 +1419,7 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { // Create the agent let agent = Agent::new(); - let new_provider = create(&provider_name, model_config)?; + let new_provider = create(&provider_name, model_config).await?; agent.update_provider(new_provider).await?; if let Some(config) = get_extension_by_name(&selected_extension_name) { agent @@ -1688,7 +1688,7 @@ pub async fn handle_openrouter_auth() -> Result<(), Box> { } }; - match create("openrouter", model_config) { + match create("openrouter", model_config).await { Ok(provider) => { // Simple test request let test_result = provider @@ -1787,7 +1787,7 @@ pub async fn handle_tetrate_auth() -> Result<(), Box> { } }; - match create("tetrate", model_config) { + match create("tetrate", model_config).await { Ok(provider) => { // Simple test request let test_result = provider diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index af0f066cb2dd..f11e688c1c8f 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -137,7 +137,6 @@ pub async fn handle_web( // Setup logging crate::logging::setup_logging(Some("goose-web"), None)?; - // Load config and create agent just like the CLI does let config = goose::config::Config::global(); let provider_name: String = match config.get_param("GOOSE_PROVIDER") { @@ -160,7 +159,7 @@ pub async fn handle_web( // Create the agent let agent = Agent::new(); - let provider = goose::providers::create(&provider_name, model_config)?; + let provider = goose::providers::create(&provider_name, model_config).await?; agent.update_provider(provider).await?; // Load and enable extensions from config diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 51e6a5982d5f..085d28720a95 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -180,7 +180,7 @@ where let original_env = setup_environment(config)?; - let inner_provider = create(&factory_name, ModelConfig::new(config.model_name)?)?; + let inner_provider = create(&factory_name, ModelConfig::new(config.model_name)?).await?; let test_provider = Arc::new(TestProvider::new_recording(inner_provider, &file_path)); ( diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index f5dde78bcecc..0c84f41bb983 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -242,7 +242,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { agent.add_final_output_tool(final_output_response).await; } - let new_provider = match create(&provider_name, model_config) { + let new_provider = match create(&provider_name, model_config).await { Ok(provider) => provider, Err(e) => { output::render_error(&format!( diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index cb80d7381fd5..0547282c6715 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -474,7 +474,7 @@ impl CliSession { RunMode::Plan => { let mut plan_messages = self.messages.clone(); plan_messages.push(Message::user().with_text(&content)); - let reasoner = get_reasoner()?; + let reasoner = get_reasoner().await?; self.plan_with_reasoner_model(plan_messages, reasoner) .await?; } @@ -581,7 +581,7 @@ impl CliSession { let mut plan_messages = self.messages.clone(); plan_messages.push(Message::user().with_text(&message_text)); - let reasoner = get_reasoner()?; + let reasoner = get_reasoner().await?; self.plan_with_reasoner_model(plan_messages, reasoner) .await?; } @@ -1632,7 +1632,7 @@ impl CliSession { } } -fn get_reasoner() -> Result, anyhow::Error> { +async fn get_reasoner() -> Result, anyhow::Error> { use goose::model::ModelConfig; use goose::providers::create; @@ -1660,7 +1660,7 @@ fn get_reasoner() -> Result, anyhow::Error> { let model_config = ModelConfig::new_with_context_env(model, Some("GOOSE_PLANNER_CONTEXT_LIMIT"))?; - let reasoner = create(&provider, model_config)?; + let reasoner = create(&provider, model_config).await?; Ok(reasoner) } diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index bad6d27a2e99..a6c76009e214 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -273,7 +273,7 @@ async fn update_agent_provider( StatusCode::BAD_REQUEST })?; - let new_provider = create(&payload.provider, model_config).map_err(|e| { + let new_provider = create(&payload.provider, model_config).await.map_err(|e| { tracing::error!("Failed to create provider: {}", e); StatusCode::BAD_REQUEST })?; diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 0bf758bd4111..d479147f648c 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -258,7 +258,7 @@ pub async fn read_all_config() -> Result, StatusCode> { ) )] pub async fn providers() -> Result>, StatusCode> { - let mut providers_metadata = get_providers(); + let mut providers_metadata = get_providers().await; let custom_providers_dir = goose::config::custom_providers::custom_providers_dir(); @@ -347,7 +347,7 @@ pub async fn providers() -> Result>, StatusCode> { pub async fn get_provider_models( Path(name): Path, ) -> Result>, StatusCode> { - let all = get_providers(); + let all = get_providers().await; let Some(metadata) = all.into_iter().find(|m| m.name == name) else { return Err(StatusCode::BAD_REQUEST); }; @@ -358,6 +358,7 @@ pub async fn get_provider_models( let model_config = ModelConfig::new(&metadata.default_model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let provider = goose::providers::create(&name, model_config) + .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; match provider.fetch_supported_models().await { @@ -449,7 +450,7 @@ pub async fn get_pricing( } } else { // Get only configured providers' pricing - let providers_metadata = get_providers(); + let providers_metadata = get_providers().await; for metadata in providers_metadata { // Skip unconfigured providers if filtering @@ -684,7 +685,7 @@ pub async fn create_custom_provider( ) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - if let Err(e) = goose::providers::refresh_custom_providers() { + if let Err(e) = goose::providers::refresh_custom_providers().await { tracing::warn!("Failed to refresh custom providers after creation: {}", e); } @@ -706,7 +707,7 @@ pub async fn remove_custom_provider( goose::config::custom_providers::CustomProviderConfig::remove(&id) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - if let Err(e) = goose::providers::refresh_custom_providers() { + if let Err(e) = goose::providers::refresh_custom_providers().await { tracing::warn!("Failed to refresh custom providers after deletion: {}", e); } diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index 74c804c26f04..d57226a7a797 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -1,21 +1,20 @@ -use std::sync::Arc; - use dotenvy::dotenv; use futures::StreamExt; use goose::agents::{Agent, AgentEvent, ExtensionConfig}; use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT}; use goose::conversation::message::Message; use goose::conversation::Conversation; -use goose::providers::databricks::DatabricksProvider; +use goose::providers::create_with_named_model; +use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; #[tokio::main] async fn main() { - // Setup a model provider from env vars let _ = dotenv(); - let provider = Arc::new(DatabricksProvider::default()); + let provider = create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL) + .await + .expect("Couldn't create provider"); - // Setup an agent with the developer extension let agent = Agent::new(); let _ = agent.update_provider(provider).await; diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index 8000a97d56be..45815bef9a05 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -1,22 +1,19 @@ use anyhow::Result; use dotenvy::dotenv; use goose::conversation::message::Message; -use goose::providers::{ - base::{Provider, Usage}, - databricks::DatabricksProvider, -}; +use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; +use goose::providers::{base::Usage, create_with_named_model}; use tokio_stream::StreamExt; #[tokio::main] async fn main() -> Result<()> { - // Load environment variables from .env file dotenv().ok(); // Clear any token to force OAuth std::env::remove_var("DATABRICKS_TOKEN"); // Create the provider - let provider = DatabricksProvider::default(); + let provider = create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL).await?; // Create a simple message let message = Message::user().with_text("Tell me a short joke about programming."); diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index d84166d31e6c..517bebd92deb 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -2,12 +2,14 @@ use anyhow::Result; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use dotenvy::dotenv; use goose::conversation::message::Message; -use goose::providers::{ - bedrock::BedrockProvider, databricks::DatabricksProvider, openai::OpenAiProvider, -}; +use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL; +use goose::providers::create_with_named_model; +use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; +use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; use rmcp::model::{CallToolRequestParam, Content, Tool}; use rmcp::object; use std::fs; +use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { @@ -15,12 +17,11 @@ async fn main() -> Result<()> { dotenv().ok(); // Create providers - let providers: Vec> = vec![ - Box::new(DatabricksProvider::default()), - Box::new(OpenAiProvider::default()), - Box::new(BedrockProvider::default()), + let providers: Vec> = vec![ + create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL).await?, + create_with_named_model("openai", OPEN_AI_DEFAULT_MODEL).await?, + create_with_named_model("anthropic", ANTHROPIC_DEFAULT_MODEL).await?, ]; - for provider in providers { // Read and encode test image let image_data = fs::read("crates/goose/examples/test_assets/test_image.png")?; diff --git a/crates/goose/src/agents/model_selector/autopilot.rs b/crates/goose/src/agents/model_selector/autopilot.rs index 7341cd1c94e8..253ddffec638 100644 --- a/crates/goose/src/agents/model_selector/autopilot.rs +++ b/crates/goose/src/agents/model_selector/autopilot.rs @@ -755,7 +755,7 @@ impl AutoPilot { self.current_role = Some(best_model.role.clone()); let model = crate::model::ModelConfig::new_or_fail(&best_model.model); - let new_provider = providers::create(&best_model.provider, model)?; + let new_provider = providers::create(&best_model.provider, model).await?; return Ok(Some(( new_provider, diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 83f4cfb054fe..25e5d3f53207 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -5,9 +5,10 @@ use crate::{ session::SessionManager, }; use anyhow::{anyhow, Result}; -use futures::future::BoxFuture; use futures::StreamExt; use rmcp::model::{ErrorCode, ErrorData}; +use std::future::Future; +use std::pin::Pin; use tracing::debug; /// Standalone function to run a complete subagent task with output options @@ -93,7 +94,7 @@ pub async fn run_complete_subagent_task( fn get_agent_messages( text_instruction: String, task_config: TaskConfig, -) -> BoxFuture<'static, Result> { +) -> Pin> + Send>> { Box::pin(async move { let agent_manager = AgentManager::instance() .await @@ -148,7 +149,7 @@ fn get_agent_messages( Ok(AgentEvent::Message(msg)) => session_messages.push(msg), Ok(AgentEvent::McpNotification(_)) | Ok(AgentEvent::ModelChange { .. }) - | Ok(AgentEvent::HistoryReplaced(_)) => {} // Handle informational events + | Ok(AgentEvent::HistoryReplaced(_)) => {} Err(e) => { tracing::error!("Error receiving message from subagent: {}", e); break; diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index 76ece5c90710..46bf84b74385 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -88,7 +88,7 @@ impl AgentManager { if let (Some(provider_name), Some(model_name)) = (provider_name, model_name) { match ModelConfig::new(&model_name) { - Ok(model_config) => match create(&provider_name, model_config) { + Ok(model_config) => match create(&provider_name, model_config).await { Ok(provider) => { self.set_default_provider(provider).await; info!( diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 3c38a34bb20b..bad4c081bea5 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -26,5 +26,3 @@ pub mod utils; #[cfg(test)] mod cron_test; -#[macro_use] -mod macros; diff --git a/crates/goose/src/macros.rs b/crates/goose/src/macros.rs deleted file mode 100644 index 01d853dec0ad..000000000000 --- a/crates/goose/src/macros.rs +++ /dev/null @@ -1,19 +0,0 @@ -#[macro_export] -macro_rules! impl_provider_default { - ($provider:ty) => { - impl Default for $provider { - fn default() -> Self { - let model = $crate::model::ModelConfig::new( - &<$provider as $crate::providers::base::Provider>::metadata().default_model, - ) - .expect(concat!( - "Failed to create model config for ", - stringify!($provider) - )); - - <$provider>::from_env(model) - .expect(concat!("Failed to initialize ", stringify!($provider))) - } - } - }; -} diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 09671af3555c..86b6953486d6 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -17,12 +17,11 @@ use super::formats::anthropic::{ use super::utils::{emit_debug_trace, get_model, map_http_error_to_provider_error}; use crate::config::custom_providers::CustomProviderConfig; use crate::conversation::message::Message; -use crate::impl_provider_default; use crate::model::ModelConfig; use crate::providers::retry::ProviderRetry; use rmcp::model::Tool; -const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-0"; +pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-0"; const ANTHROPIC_DEFAULT_FAST_MODEL: &str = "claude-3-7-sonnet-latest"; const ANTHROPIC_KNOWN_MODELS: &[&str] = &[ "claude-sonnet-4-0", @@ -45,10 +44,8 @@ pub struct AnthropicProvider { supports_streaming: bool, } -impl_provider_default!(AnthropicProvider); - impl AnthropicProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let model = model.with_fast(ANTHROPIC_DEFAULT_FAST_MODEL.to_string()); let config = crate::config::Config::global(); diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 12bf93f20e3d..e261c2a8919c 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -11,7 +11,6 @@ use super::formats::openai::{create_request, get_usage, response_to_message}; use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::conversation::message::Message; -use crate::impl_provider_default; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -68,10 +67,8 @@ impl AuthProvider for AzureAuthProvider { } } -impl_provider_default!(AzureProvider); - impl AzureProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?; let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?; diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index e7246aa9ae62..7ce5db9fac3b 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -4,7 +4,6 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::retry::{ProviderRetry, RetryConfig}; use crate::conversation::message::Message; -use crate::impl_provider_default; use crate::model::ModelConfig; use crate::providers::utils::emit_debug_trace; use anyhow::Result; @@ -46,7 +45,7 @@ pub struct BedrockProvider { } impl BedrockProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); // Attempt to load config and secrets to get AWS_ prefixed keys @@ -63,15 +62,14 @@ impl BedrockProvider { set_aws_env_vars(config.load_values()); set_aws_env_vars(config.load_secrets()); - let sdk_config = futures::executor::block_on(aws_config::load_from_env()); + let sdk_config = aws_config::load_from_env().await; // validate credentials or return error back up - futures::executor::block_on( - sdk_config - .credentials_provider() - .unwrap() - .provide_credentials(), - )?; + sdk_config + .credentials_provider() + .unwrap() + .provide_credentials() + .await?; let client = Client::new(&sdk_config); let retry_config = Self::load_retry_config(config); @@ -172,8 +170,6 @@ impl BedrockProvider { } } -impl_provider_default!(BedrockProvider); - #[async_trait] impl Provider for BedrockProvider { fn metadata() -> ProviderMetadata { diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 16c36c23326b..571004028794 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -12,7 +12,6 @@ use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::config::Config; use crate::conversation::message::{Message, MessageContent}; -use crate::impl_provider_default; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -27,10 +26,8 @@ pub struct ClaudeCodeProvider { model: ModelConfig, } -impl_provider_default!(ClaudeCodeProvider); - impl ClaudeCodeProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let command: String = config .get_param("CLAUDE_CODE_COMMAND") @@ -518,16 +515,6 @@ mod tests { use super::ModelConfig; use super::*; - #[test] - fn test_claude_code_model_config() { - let provider = ClaudeCodeProvider::default(); - let config = provider.get_model_config(); - - assert_eq!(config.model_name, "claude-sonnet-4-20250514"); - // Context limit should be set by the ModelConfig - assert!(config.context_limit() > 0); - } - #[test] fn test_permission_mode_flag_construction() { // Test that in auto mode, the --permission-mode acceptEdits flag is added @@ -540,21 +527,21 @@ mod tests { std::env::remove_var("GOOSE_MODE"); } - #[test] - fn test_claude_code_invalid_model_no_fallback() { + #[tokio::test] + async fn test_claude_code_invalid_model_no_fallback() { // Test that an invalid model is kept as-is (no fallback) let invalid_model = ModelConfig::new_or_fail("invalid-model"); - let provider = ClaudeCodeProvider::from_env(invalid_model).unwrap(); + let provider = ClaudeCodeProvider::from_env(invalid_model).await.unwrap(); let config = provider.get_model_config(); assert_eq!(config.model_name, "invalid-model"); } - #[test] - fn test_claude_code_valid_model() { + #[tokio::test] + async fn test_claude_code_valid_model() { // Test that a valid model is preserved let valid_model = ModelConfig::new_or_fail("sonnet"); - let provider = ClaudeCodeProvider::from_env(valid_model).unwrap(); + let provider = ClaudeCodeProvider::from_env(valid_model).await.unwrap(); let config = provider.get_model_config(); assert_eq!(config.model_name, "sonnet"); diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index e76f87cfdde0..0992bd49c61a 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -11,7 +11,6 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::conversation::message::{Message, MessageContent}; -use crate::impl_provider_default; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -26,10 +25,8 @@ pub struct CursorAgentProvider { model: ModelConfig, } -impl_provider_default!(CursorAgentProvider); - impl CursorAgentProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let command: String = config .get_param("CURSOR_AGENT_COMMAND") @@ -450,46 +447,13 @@ mod tests { use super::ModelConfig; use super::*; - #[test] - fn test_cursor_agent_model_config() { - let provider = CursorAgentProvider::default(); - let config = provider.get_model_config(); - - assert_eq!(config.model_name, "auto"); - // Context limit should be set by the ModelConfig - assert!(config.context_limit() > 0); - } - - #[test] - fn test_cursor_agent_invalid_model_no_fallback() { - // Test that an invalid model is kept as-is (no fallback) - let invalid_model = ModelConfig::new_or_fail("invalid-model"); - let provider = CursorAgentProvider::from_env(invalid_model).unwrap(); - let config = provider.get_model_config(); - - assert_eq!(config.model_name, "invalid-model"); - } - - #[test] - fn test_cursor_agent_valid_model() { + #[tokio::test] + async fn test_cursor_agent_valid_model() { // Test that a valid model is preserved let valid_model = ModelConfig::new_or_fail("gpt-5"); - let provider = CursorAgentProvider::from_env(valid_model).unwrap(); + let provider = CursorAgentProvider::from_env(valid_model).await.unwrap(); let config = provider.get_model_config(); assert_eq!(config.model_name, "gpt-5"); } - - #[test] - fn test_filter_extensions_from_system_prompt() { - let provider = CursorAgentProvider::default(); - - let system_with_extensions = "Some system prompt\n\n# Extensions\nSome extension info\n\n# Next Section\nMore content"; - let filtered = provider.filter_extensions_from_system_prompt(system_with_extensions); - assert_eq!(filtered, "Some system prompt\n# Next Section\nMore content"); - - let system_without_extensions = "Some system prompt\n\n# Other Section\nContent"; - let filtered = provider.filter_extensions_from_system_prompt(system_without_extensions); - assert_eq!(filtered, system_without_extensions); - } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index f70ee14d24ad..40847bb3ed1d 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -21,7 +21,6 @@ use super::utils::{ }; use crate::config::ConfigError; use crate::conversation::message::Message; -use crate::impl_provider_default; use crate::model::ModelConfig; use crate::providers::formats::openai::{get_usage, response_to_streaming_message}; use crate::providers::retry::{ @@ -109,10 +108,8 @@ pub struct DatabricksProvider { retry_config: RetryConfig, } -impl_provider_default!(DatabricksProvider); - impl DatabricksProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let mut host: Result = config.get_param("DATABRICKS_HOST"); diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 084b29fea36b..be6b6eb59733 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -28,56 +28,63 @@ use super::{ use crate::config::custom_providers::{custom_providers_dir, register_custom_providers}; use crate::model::ModelConfig; use anyhow::Result; -use once_cell::sync::Lazy; +use tokio::sync::OnceCell; const DEFAULT_LEAD_TURNS: usize = 3; const DEFAULT_FAILURE_THRESHOLD: usize = 2; const DEFAULT_FALLBACK_TURNS: usize = 2; -static REGISTRY: Lazy> = Lazy::new(|| { - let registry = ProviderRegistry::new().with_providers(|registry| { - registry.register::(AnthropicProvider::from_env); - registry.register::(AzureProvider::from_env); - registry.register::(BedrockProvider::from_env); - registry.register::(ClaudeCodeProvider::from_env); - registry.register::(CursorAgentProvider::from_env); - registry.register::(DatabricksProvider::from_env); - registry.register::(GcpVertexAIProvider::from_env); - registry.register::(GeminiCliProvider::from_env); - registry.register::(GithubCopilotProvider::from_env); - registry.register::(GoogleProvider::from_env); - registry.register::(GroqProvider::from_env); - registry.register::(LiteLLMProvider::from_env); - registry.register::(OllamaProvider::from_env); - registry.register::(OpenAiProvider::from_env); - registry.register::(OpenRouterProvider::from_env); - registry.register::(SageMakerTgiProvider::from_env); - registry.register::(SnowflakeProvider::from_env); - registry.register::(TetrateProvider::from_env); - registry.register::(VeniceProvider::from_env); - registry.register::(XaiProvider::from_env); - - if let Err(e) = load_custom_providers_into_registry(registry) { - tracing::warn!("Failed to load custom providers: {}", e); - } +static REGISTRY: OnceCell> = OnceCell::const_new(); + +async fn init_registry() -> RwLock { + let mut registry = ProviderRegistry::new().with_providers(|registry| { + registry.register::(|m| Box::pin(AnthropicProvider::from_env(m))); + registry.register::(|m| Box::pin(AzureProvider::from_env(m))); + registry.register::(|m| Box::pin(BedrockProvider::from_env(m))); + registry.register::(|m| Box::pin(ClaudeCodeProvider::from_env(m))); + registry.register::(|m| Box::pin(CursorAgentProvider::from_env(m))); + registry.register::(|m| Box::pin(DatabricksProvider::from_env(m))); + registry.register::(|m| Box::pin(GcpVertexAIProvider::from_env(m))); + registry.register::(|m| Box::pin(GeminiCliProvider::from_env(m))); + registry + .register::(|m| Box::pin(GithubCopilotProvider::from_env(m))); + registry.register::(|m| Box::pin(GoogleProvider::from_env(m))); + registry.register::(|m| Box::pin(GroqProvider::from_env(m))); + registry.register::(|m| Box::pin(LiteLLMProvider::from_env(m))); + registry.register::(|m| Box::pin(OllamaProvider::from_env(m))); + registry.register::(|m| Box::pin(OpenAiProvider::from_env(m))); + registry.register::(|m| Box::pin(OpenRouterProvider::from_env(m))); + registry + .register::(|m| Box::pin(SageMakerTgiProvider::from_env(m))); + registry.register::(|m| Box::pin(SnowflakeProvider::from_env(m))); + registry.register::(|m| Box::pin(TetrateProvider::from_env(m))); + registry.register::(|m| Box::pin(VeniceProvider::from_env(m))); + registry.register::(|m| Box::pin(XaiProvider::from_env(m))); }); + if let Err(e) = load_custom_providers_into_registry(&mut registry) { + tracing::warn!("Failed to load custom providers: {}", e); + } RwLock::new(registry) -}); +} fn load_custom_providers_into_registry(registry: &mut ProviderRegistry) -> Result<()> { let config_dir = custom_providers_dir(); register_custom_providers(registry, &config_dir) } -pub fn providers() -> Vec { - REGISTRY.read().unwrap().all_metadata() +async fn get_registry() -> &'static RwLock { + REGISTRY.get_or_init(init_registry).await } -pub fn refresh_custom_providers() -> Result<()> { - let mut registry = REGISTRY.write().unwrap(); - registry.remove_custom_providers(); +pub async fn providers() -> Vec { + get_registry().await.read().unwrap().all_metadata() +} - if let Err(e) = load_custom_providers_into_registry(&mut registry) { +pub async fn refresh_custom_providers() -> Result<()> { + let registry = get_registry().await; + registry.write().unwrap().remove_custom_providers(); + + if let Err(e) = load_custom_providers_into_registry(&mut registry.write().unwrap()) { tracing::warn!("Failed to refresh custom providers: {}", e); return Err(e); } @@ -86,18 +93,36 @@ pub fn refresh_custom_providers() -> Result<()> { Ok(()) } -pub fn create(name: &str, model: ModelConfig) -> Result> { +pub async fn create(name: &str, model: ModelConfig) -> Result> { let config = crate::config::Config::global(); if let Ok(lead_model_name) = config.get_param::("GOOSE_LEAD_MODEL") { tracing::info!("Creating lead/worker provider from environment variables"); - return create_lead_worker_from_env(name, &model, &lead_model_name); + return create_lead_worker_from_env(name, &model, &lead_model_name).await; } - REGISTRY.read().unwrap().create(name, model) + let registry = get_registry().await; + let constructor = { + let guard = registry.read().unwrap(); + guard + .entries + .get(name) + .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))? + .constructor + .clone() + }; + constructor(model).await } -fn create_lead_worker_from_env( +pub async fn create_with_named_model( + provider_name: &str, + model_name: &str, +) -> Result> { + let config = ModelConfig::new(model_name)?; + create(provider_name, config).await +} + +async fn create_lead_worker_from_env( default_provider_name: &str, default_model: &ModelConfig, lead_model_name: &str, @@ -125,14 +150,30 @@ fn create_lead_worker_from_env( let worker_model_config = create_worker_model_config(default_model)?; - let lead_provider = REGISTRY - .read() - .unwrap() - .create(&lead_provider_name, lead_model_config)?; - let worker_provider = REGISTRY - .read() - .unwrap() - .create(default_provider_name, worker_model_config)?; + let registry = get_registry().await; + + let lead_constructor = { + let guard = registry.read().unwrap(); + guard + .entries + .get(&lead_provider_name) + .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", lead_provider_name))? + .constructor + .clone() + }; + + let worker_constructor = { + let guard = registry.read().unwrap(); + guard + .entries + .get(default_provider_name) + .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", default_provider_name))? + .constructor + .clone() + }; + + let lead_provider = lead_constructor(lead_model_config).await?; + let worker_provider = worker_constructor(worker_model_config).await?; Ok(Arc::new(LeadWorkerProvider::new_with_settings( lead_provider, @@ -205,8 +246,8 @@ mod tests { } } - #[test] - fn test_create_lead_worker_provider() { + #[tokio::test] + async fn test_create_lead_worker_provider() { let _guard = EnvVarGuard::new(&[ "GOOSE_LEAD_MODEL", "GOOSE_LEAD_PROVIDER", @@ -216,7 +257,7 @@ mod tests { _guard.set("GOOSE_LEAD_MODEL", "gpt-4o"); let gpt4mini_config = ModelConfig::new_or_fail("gpt-4o-mini"); - let result = create("openai", gpt4mini_config.clone()); + let result = create("openai", gpt4mini_config.clone()).await; match result { Ok(_) => {} @@ -229,11 +270,11 @@ mod tests { _guard.set("GOOSE_LEAD_PROVIDER", "anthropic"); _guard.set("GOOSE_LEAD_TURNS", "5"); - let _result = create("openai", gpt4mini_config); + let _result = create("openai", gpt4mini_config).await; } - #[test] - fn test_lead_model_env_vars_with_defaults() { + #[tokio::test] + async fn test_lead_model_env_vars_with_defaults() { let _guard = EnvVarGuard::new(&[ "GOOSE_LEAD_MODEL", "GOOSE_LEAD_PROVIDER", @@ -244,7 +285,7 @@ mod tests { _guard.set("GOOSE_LEAD_MODEL", "grok-3"); - let result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")); + let result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")).await; match result { Ok(_) => {} @@ -261,8 +302,8 @@ mod tests { let _result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")); } - #[test] - fn test_create_regular_provider_without_lead_config() { + #[tokio::test] + async fn test_create_regular_provider_without_lead_config() { let _guard = EnvVarGuard::new(&[ "GOOSE_LEAD_MODEL", "GOOSE_LEAD_PROVIDER", @@ -271,7 +312,7 @@ mod tests { "GOOSE_LEAD_FALLBACK_TURNS", ]); - let result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")); + let result = create("openai", ModelConfig::new_or_fail("gpt-4o-mini")).await; match result { Ok(_) => {} diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 3fa69ea9b3e9..8c3d4c765df1 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -18,7 +18,6 @@ use crate::providers::formats::gcpvertexai::{ ModelProvider, RequestContext, }; -use crate::impl_provider_default; use crate::providers::formats::gcpvertexai::GcpLocation::Iowa; use crate::providers::gcpauth::GcpAuth; use crate::providers::retry::RetryConfig; @@ -87,23 +86,7 @@ impl GcpVertexAIProvider { /// /// # Arguments /// * `model` - Configuration for the model to be used - pub fn from_env(model: ModelConfig) -> Result { - Self::new(model) - } - - /// Creates a new provider instance with the specified model configuration. - /// - /// # Arguments - /// * `model` - Configuration for the model to be used - pub fn new(model: ModelConfig) -> Result { - futures::executor::block_on(Self::new_async(model)) - } - - /// Async implementation of new provider instance creation. - /// - /// # Arguments - /// * `model` - Configuration for the model to be used - async fn new_async(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let project_id = config.get_param("GCP_PROJECT_ID")?; let location = Self::determine_location(config)?; @@ -445,8 +428,6 @@ impl GcpVertexAIProvider { } } -impl_provider_default!(GcpVertexAIProvider); - #[async_trait] impl Provider for GcpVertexAIProvider { /// Returns metadata about the GCP Vertex AI provider. diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 35d59fb8b18d..22d0a04ebd46 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -10,7 +10,7 @@ use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::conversation::message::{Message, MessageContent}; -use crate::impl_provider_default; + use crate::model::ModelConfig; use rmcp::model::Role; use rmcp::model::Tool; @@ -26,10 +26,8 @@ pub struct GeminiCliProvider { model: ModelConfig, } -impl_provider_default!(GeminiCliProvider); - impl GeminiCliProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let command: String = config .get_param("GEMINI_CLI_COMMAND") @@ -364,31 +362,21 @@ impl Provider for GeminiCliProvider { mod tests { use super::*; - #[test] - fn test_gemini_cli_model_config() { - let provider = GeminiCliProvider::default(); - let config = provider.get_model_config(); - - assert_eq!(config.model_name, "gemini-2.5-pro"); - // Context limit should be set by the ModelConfig - assert!(config.context_limit() > 0); - } - - #[test] - fn test_gemini_cli_invalid_model_no_fallback() { + #[tokio::test] + async fn test_gemini_cli_invalid_model_no_fallback() { // Test that an invalid model is kept as-is (no fallback) let invalid_model = ModelConfig::new_or_fail("invalid-model"); - let provider = GeminiCliProvider::from_env(invalid_model).unwrap(); + let provider = GeminiCliProvider::from_env(invalid_model).await.unwrap(); let config = provider.get_model_config(); assert_eq!(config.model_name, "invalid-model"); } - #[test] - fn test_gemini_cli_valid_model() { + #[tokio::test] + async fn test_gemini_cli_valid_model() { // Test that a valid model is preserved let valid_model = ModelConfig::new_or_fail(GEMINI_CLI_DEFAULT_MODEL); - let provider = GeminiCliProvider::from_env(valid_model).unwrap(); + let provider = GeminiCliProvider::from_env(valid_model).await.unwrap(); let config = provider.get_model_config(); assert_eq!(config.model_name, GEMINI_CLI_DEFAULT_MODEL); diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 3a4a4c55508c..a6832d529a08 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -19,7 +19,7 @@ use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, I use crate::config::{Config, ConfigError}; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::base::ConfigKey; use rmcp::model::Tool; @@ -115,10 +115,8 @@ pub struct GithubCopilotProvider { model: ModelConfig, } -impl_provider_default!(GithubCopilotProvider); - impl GithubCopilotProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let client = Client::builder() .timeout(Duration::from_secs(600)) .build()?; diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 4bf843088e34..4cb907b02371 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -3,7 +3,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, handle_response_google_compat, unescape_json_values}; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use crate::providers::formats::google::{create_request, get_usage, response_to_message}; @@ -52,10 +52,8 @@ pub struct GoogleProvider { model: ModelConfig, } -impl_provider_default!(GoogleProvider); - impl GoogleProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let model = model.with_fast(GOOGLE_DEFAULT_FAST_MODEL.to_string()); let config = crate::config::Config::global(); diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index e70a9d36aa58..dec37a278887 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -3,7 +3,6 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; use crate::conversation::message::Message; -use crate::impl_provider_default; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; @@ -30,10 +29,8 @@ pub struct GroqProvider { model: ModelConfig, } -impl_provider_default!(GroqProvider); - impl GroqProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let api_key: String = config.get_secret("GROQ_API_KEY")?; let host: String = config diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index 7d99d28d1b9b..7926e202e194 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -10,7 +10,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use rmcp::model::Tool; @@ -25,10 +25,8 @@ pub struct LiteLLMProvider { model: ModelConfig, } -impl_provider_default!(LiteLLMProvider); - impl LiteLLMProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let api_key: String = config .get_secret("LITELLM_API_KEY") diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 64b29ed32e22..52e8ba0185b9 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -37,4 +37,4 @@ pub mod utils_universal_openai_stream; pub mod venice; pub mod xai; -pub use factory::{create, providers, refresh_custom_providers}; +pub use factory::{create, create_with_named_model, providers, refresh_custom_providers}; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 95c57a65c414..15aeca146fc7 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -6,7 +6,7 @@ use super::utils::{get_model, handle_response_openai_compat, handle_status_opena use crate::config::custom_providers::CustomProviderConfig; use crate::conversation::message::Message; use crate::conversation::Conversation; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::formats::openai::{ create_request, get_usage, response_to_message, response_to_streaming_message, @@ -43,10 +43,8 @@ pub struct OllamaProvider { supports_streaming: bool, } -impl_provider_default!(OllamaProvider); - impl OllamaProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let host: String = config .get_param("OLLAMA_HOST") diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 3f493dcd8f6e..9c5794c1de24 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -22,7 +22,7 @@ use super::utils::{ }; use crate::config::custom_providers::CustomProviderConfig; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::base::MessageStream; use crate::providers::formats::openai::response_to_streaming_message; @@ -56,10 +56,8 @@ pub struct OpenAiProvider { supports_streaming: bool, } -impl_provider_default!(OpenAiProvider); - impl OpenAiProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string()); let config = crate::config::Config::global(); diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 4017fed5a5aa..85884a25a514 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -11,7 +11,7 @@ use super::utils::{ is_google_model, }; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use rmcp::model::Tool; @@ -42,10 +42,8 @@ pub struct OpenRouterProvider { model: ModelConfig, } -impl_provider_default!(OpenRouterProvider); - impl OpenRouterProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let model = model.with_fast(OPENROUTER_DEFAULT_FAST_MODEL.to_string()); let config = crate::config::Config::global(); diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index e20bc98dd106..5a0ebe4332e4 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -1,19 +1,21 @@ use super::base::{Provider, ProviderMetadata}; use crate::model::ModelConfig; use anyhow::Result; +use futures::future::BoxFuture; use std::collections::HashMap; use std::sync::Arc; -type ProviderConstructor = Box Result> + Send + Sync>; +type ProviderConstructor = + Arc BoxFuture<'static, Result>> + Send + Sync>; -struct ProviderEntry { +pub struct ProviderEntry { metadata: ProviderMetadata, - constructor: ProviderConstructor, + pub(crate) constructor: ProviderConstructor, } #[derive(Default)] pub struct ProviderRegistry { - entries: HashMap, + pub(crate) entries: HashMap, } impl ProviderRegistry { @@ -26,7 +28,7 @@ impl ProviderRegistry { pub fn register(&mut self, constructor: F) where P: Provider + 'static, - F: Fn(ModelConfig) -> Result

+ Send + Sync + 'static, + F: Fn(ModelConfig) -> BoxFuture<'static, Result

> + Send + Sync + 'static, { let metadata = P::metadata(); let name = metadata.name.clone(); @@ -35,12 +37,17 @@ impl ProviderRegistry { name, ProviderEntry { metadata, - constructor: Box::new(move |model| Ok(Arc::new(constructor(model)?))), + constructor: Arc::new(move |model| { + let fut = constructor(model); + Box::pin(async move { + let provider = fut.await?; + Ok(Arc::new(provider) as Arc) + }) + }), }, ); } - /// create provider with custom name pub fn register_with_name( &mut self, custom_name: String, @@ -68,7 +75,13 @@ impl ProviderRegistry { custom_name, ProviderEntry { metadata: custom_metadata, - constructor: Box::new(move |model| Ok(Arc::new(constructor(model)?))), + constructor: Arc::new(move |model| { + let result = constructor(model); + Box::pin(async move { + let provider = result?; + Ok(Arc::new(provider) as Arc) + }) + }), }, ); } @@ -81,15 +94,13 @@ impl ProviderRegistry { self } - pub fn create(&self, name: &str, model: ModelConfig) -> Result> { - let _available_providers: Vec<_> = self.entries.keys().collect(); - + pub async fn create(&self, name: &str, model: ModelConfig) -> Result> { let entry = self .entries .get(name) .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?; - (entry.constructor)(model) + (entry.constructor)(model).await } pub fn all_metadata(&self) -> Vec { diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 0c65e32d75cc..5313e7705ba0 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -14,7 +14,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::emit_debug_trace; use crate::conversation::message::{Message, MessageContent}; -use crate::impl_provider_default; + use crate::model::ModelConfig; use chrono::Utc; use rmcp::model::Role; @@ -33,7 +33,7 @@ pub struct SageMakerTgiProvider { } impl SageMakerTgiProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); // Get SageMaker endpoint name (just the name, not full URL) @@ -54,15 +54,14 @@ impl SageMakerTgiProvider { set_aws_env_vars(config.load_values()); set_aws_env_vars(config.load_secrets()); - let aws_config = futures::executor::block_on(aws_config::load_from_env()); + let aws_config = aws_config::load_from_env().await; // Validate credentials - futures::executor::block_on( - aws_config - .credentials_provider() - .unwrap() - .provide_credentials(), - )?; + aws_config + .credentials_provider() + .unwrap() + .provide_credentials() + .await?; // Create client with longer timeout for model initialization let timeout_config = aws_config::timeout::TimeoutConfig::builder() @@ -255,8 +254,6 @@ impl SageMakerTgiProvider { } } -impl_provider_default!(SageMakerTgiProvider); - #[async_trait] impl Provider for SageMakerTgiProvider { fn metadata() -> ProviderMetadata { diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index f88869cc3ca6..aef57b5ae0ac 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -11,7 +11,7 @@ use super::retry::ProviderRetry; use super::utils::{get_model, map_http_error_to_provider_error, ImageFormat}; use crate::config::ConfigError; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use rmcp::model::Tool; @@ -40,10 +40,8 @@ pub struct SnowflakeProvider { image_format: ImageFormat, } -impl_provider_default!(SnowflakeProvider); - impl SnowflakeProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let mut host: Result = config.get_param("SNOWFLAKE_HOST"); if host.is_err() { diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index a12a9882047c..c4d5924b92bb 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -20,7 +20,7 @@ use super::utils::{ }; use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use rmcp::model::Tool; @@ -48,10 +48,8 @@ pub struct TetrateProvider { supports_streaming: bool, } -impl_provider_default!(TetrateProvider); - impl TetrateProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let api_key: String = config.get_secret("TETRATE_API_KEY")?; // API host for LLM endpoints (/v1/chat/completions, /v1/models) diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 9e14b04e50a9..701af251be4e 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -10,7 +10,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::map_http_error_to_provider_error; use crate::conversation::message::{Message, MessageContent}; -use crate::impl_provider_default; + use crate::mcp_utils::ToolResult; use crate::model::ModelConfig; use rmcp::model::{object, CallToolRequestParam, Role, Tool}; @@ -80,10 +80,8 @@ pub struct VeniceProvider { model: ModelConfig, } -impl_provider_default!(VeniceProvider); - impl VeniceProvider { - pub fn from_env(mut model: ModelConfig) -> Result { + pub async fn from_env(mut model: ModelConfig) -> Result { let config = crate::config::Config::global(); let api_key: String = config.get_secret("VENICE_API_KEY")?; let host: String = config diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index 202188d15a5e..bb52b0e6c09e 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -3,7 +3,7 @@ use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; use crate::conversation::message::Message; -use crate::impl_provider_default; + use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; @@ -44,10 +44,8 @@ pub struct XaiProvider { model: ModelConfig, } -impl_provider_default!(XaiProvider); - impl XaiProvider { - pub fn from_env(model: ModelConfig) -> Result { + pub async fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); let api_key: String = config.get_secret("XAI_API_KEY")?; let host: String = config diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 93b696c820b2..42273229e385 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1127,13 +1127,16 @@ async fn run_scheduled_job_internal( error: format!("Model config error: {}", e), })?; - agent_provider = create(&provider_name, model_config).map_err(|e| JobExecutionError { - job_id: job.id.clone(), - error: format!( - "Failed to create provider instance '{}': {}", - provider_name, e - ), - })?; + agent_provider = + create(&provider_name, model_config) + .await + .map_err(|e| JobExecutionError { + job_id: job.id.clone(), + error: format!( + "Failed to create provider instance '{}': {}", + provider_name, e + ), + })?; } if let Some(ref recipe_extensions) = recipe.extensions { diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 1c366024427c..157c588817e7 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -1,5 +1,3 @@ -// src/lib.rs or tests/truncate_agent_tests.rs - use std::sync::Arc; use anyhow::Result; @@ -71,19 +69,21 @@ impl ProviderType { } } - fn create_provider(&self, model_config: ModelConfig) -> Result> { + async fn create_provider(&self, model_config: ModelConfig) -> Result> { Ok(match self { - ProviderType::Azure => Arc::new(AzureProvider::from_env(model_config)?), - ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(model_config)?), - ProviderType::Anthropic => Arc::new(AnthropicProvider::from_env(model_config)?), - ProviderType::Bedrock => Arc::new(BedrockProvider::from_env(model_config)?), - ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(model_config)?), - ProviderType::GcpVertexAI => Arc::new(GcpVertexAIProvider::from_env(model_config)?), - ProviderType::Google => Arc::new(GoogleProvider::from_env(model_config)?), - ProviderType::Groq => Arc::new(GroqProvider::from_env(model_config)?), - ProviderType::Ollama => Arc::new(OllamaProvider::from_env(model_config)?), - ProviderType::OpenRouter => Arc::new(OpenRouterProvider::from_env(model_config)?), - ProviderType::Xai => Arc::new(XaiProvider::from_env(model_config)?), + ProviderType::Azure => Arc::new(AzureProvider::from_env(model_config).await?), + ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(model_config).await?), + ProviderType::Anthropic => Arc::new(AnthropicProvider::from_env(model_config).await?), + ProviderType::Bedrock => Arc::new(BedrockProvider::from_env(model_config).await?), + ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(model_config).await?), + ProviderType::GcpVertexAI => { + Arc::new(GcpVertexAIProvider::from_env(model_config).await?) + } + ProviderType::Google => Arc::new(GoogleProvider::from_env(model_config).await?), + ProviderType::Groq => Arc::new(GroqProvider::from_env(model_config).await?), + ProviderType::Ollama => Arc::new(OllamaProvider::from_env(model_config).await?), + ProviderType::OpenRouter => Arc::new(OpenRouterProvider::from_env(model_config).await?), + ProviderType::Xai => Arc::new(XaiProvider::from_env(model_config).await?), }) } } @@ -114,7 +114,7 @@ async fn run_truncate_test( .unwrap() .with_context_limit(Some(context_window)) .with_temperature(Some(0.0)); - let provider = provider_type.create_provider(model_config)?; + let provider = provider_type.create_provider(model_config).await?; let agent = Agent::new(); agent.update_provider(provider).await?; diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 8830f7801b8b..85449c845194 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -1,12 +1,21 @@ use anyhow::Result; use dotenvy::dotenv; use goose::conversation::message::{Message, MessageContent}; +use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL; +use goose::providers::azure::AZURE_DEFAULT_MODEL; use goose::providers::base::Provider; +use goose::providers::bedrock::BEDROCK_DEFAULT_MODEL; +use goose::providers::create_with_named_model; +use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::providers::errors::ProviderError; -use goose::providers::{ - anthropic, azure, bedrock, databricks, google, groq, litellm, ollama, openai, openrouter, - snowflake, xai, -}; +use goose::providers::google::GOOGLE_DEFAULT_MODEL; +use goose::providers::groq::GROQ_DEFAULT_MODEL; +use goose::providers::litellm::LITELLM_DEFAULT_MODEL; +use goose::providers::ollama::OLLAMA_DEFAULT_MODEL; +use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; +use goose::providers::sagemaker_tgi::SAGEMAKER_TGI_DEFAULT_MODEL; +use goose::providers::snowflake::SNOWFLAKE_DEFAULT_MODEL; +use goose::providers::xai::XAI_DEFAULT_MODEL; use rmcp::model::{AnnotateAble, Content, RawImageContent}; use rmcp::model::{CallToolRequestParam, Tool}; use rmcp::object; @@ -77,18 +86,14 @@ lazy_static::lazy_static! { static ref ENV_LOCK: Mutex<()> = Mutex::new(()); } -/// Generic test harness for any Provider implementation struct ProviderTester { provider: Arc, name: String, } impl ProviderTester { - fn new(provider: T, name: String) -> Self { - Self { - provider: Arc::new(provider), - name, - } + fn new(provider: Arc, name: String) -> Self { + Self { provider, name } } async fn test_basic_response(&self) -> Result<()> { @@ -99,14 +104,12 @@ impl ProviderTester { .complete("You are a helpful assistant.", &[message], &[]) .await?; - // For a basic response, we expect a single text response assert_eq!( response.content.len(), 1, "Expected single content item in response" ); - // Verify we got a text response assert!( matches!(response.content[0], MessageContent::Text(_)), "Expected text response" @@ -146,7 +149,6 @@ impl ProviderTester { dbg!(&response1); println!("==================="); - // Verify we got a tool request assert!( response1 .content @@ -177,7 +179,6 @@ impl ProviderTester { )]), ); - // Verify we construct a valid payload including the request/response pair for the next inference let (response2, _) = self .provider .complete( @@ -203,7 +204,6 @@ impl ProviderTester { } async fn test_context_length_exceeded_error(&self) -> Result<()> { - // Google Gemini has a really long context window let large_message_content = if self.name.to_lowercase() == "google" { "hello ".repeat(1_300_000) } else { @@ -215,7 +215,6 @@ impl ProviderTester { Message::assistant().with_text("hey! I think it's 4."), Message::user().with_text(&large_message_content), Message::assistant().with_text("heyy!!"), - // Messages before this mark should be truncated Message::user().with_text("what's the meaning of life?"), Message::assistant().with_text("the meaning of life is 42"), Message::user().with_text( @@ -223,18 +222,15 @@ impl ProviderTester { ), ]; - // Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded let result = self .provider .complete("You are a helpful assistant.", &messages, &[]) .await; - // Print some debug info println!("=== {}::context_length_exceeded_error ===", self.name); dbg!(&result); println!("==================="); - // Ollama truncates by default even when the context window is exceeded if self.name.to_lowercase() == "ollama" { assert!( result.is_ok(), @@ -260,7 +256,6 @@ impl ProviderTester { use goose::conversation::message::Message; use std::fs; - // Try to read the test image let image_path = "crates/goose/examples/test_assets/test_image.png"; let image_data = match fs::read(image_path) { Ok(data) => data, @@ -281,7 +276,6 @@ impl ProviderTester { } .no_annotation(); - // Test 1: Direct image message let message_with_image = Message::user().with_image(image_content.data.clone(), image_content.mime_type.clone()); @@ -297,7 +291,6 @@ impl ProviderTester { println!("=== {}::image_content_support ===", self.name); let (response, _) = result?; println!("Image response: {:?}", response); - // Verify we got a text response assert!( response .content @@ -307,7 +300,6 @@ impl ProviderTester { ); println!("==================="); - // Test 2: Tool response with image (this should be handled gracefully) let screenshot_tool = Tool::new( "get_screenshot", "Get a screenshot of the current screen", @@ -350,7 +342,6 @@ impl ProviderTester { Ok(()) } - /// Run all provider tests async fn run_test_suite(&self) -> Result<()> { self.test_basic_response().await?; self.test_tool_usage().await?; @@ -366,74 +357,75 @@ fn load_env() { } } -/// Helper function to run a provider test with proper error handling and reporting -async fn test_provider( +async fn test_provider( name: &str, + model_name: &str, required_vars: &[&str], env_modifications: Option>>, - provider_fn: F, -) -> Result<()> -where - F: FnOnce() -> T, - T: Provider + Send + Sync + 'static, -{ - // We start off as failed, so that if the process panics it is seen as a failure +) -> Result<()> { TEST_REPORT.record_fail(name); - // Take exclusive access to environment modifications - let lock = ENV_LOCK.lock().unwrap(); + let original_env = { + let _lock = ENV_LOCK.lock().unwrap(); - load_env(); + load_env(); - // Save current environment state for required vars and modified vars - let mut original_env = HashMap::new(); - for &var in required_vars { - if let Ok(val) = std::env::var(var) { - original_env.insert(var, val); - } - } - if let Some(mods) = &env_modifications { - for &var in mods.keys() { + let mut original_env = HashMap::new(); + for &var in required_vars { if let Ok(val) = std::env::var(var) { original_env.insert(var, val); } } - } + if let Some(mods) = &env_modifications { + for &var in mods.keys() { + if let Ok(val) = std::env::var(var) { + original_env.insert(var, val); + } + } + } - // Apply any environment modifications - if let Some(mods) = &env_modifications { - for (&var, value) in mods.iter() { - match value { - Some(val) => std::env::set_var(var, val), - None => std::env::remove_var(var), + if let Some(mods) = &env_modifications { + for (&var, value) in mods.iter() { + match value { + Some(val) => std::env::set_var(var, val), + None => std::env::remove_var(var), + } } } - } - // Setup the provider - let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err()); - if missing_vars { - println!("Skipping {} tests - credentials not configured", name); - TEST_REPORT.record_skip(name); - return Ok(()); - } + let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err()); + if missing_vars { + println!("Skipping {} tests - credentials not configured", name); + TEST_REPORT.record_skip(name); + return Ok(()); + } - let provider = provider_fn(); + original_env + }; - // Restore original environment - for (&var, value) in original_env.iter() { - std::env::set_var(var, value); - } - if let Some(mods) = env_modifications { - for &var in mods.keys() { - if !original_env.contains_key(var) { - std::env::remove_var(var); + let provider = match create_with_named_model(&name.to_lowercase(), model_name).await { + Ok(p) => p, + Err(e) => { + println!("Skipping {} tests - failed to create provider: {}", name, e); + TEST_REPORT.record_skip(name); + return Ok(()); + } + }; + + { + let _lock = ENV_LOCK.lock().unwrap(); + for (&var, value) in original_env.iter() { + std::env::set_var(var, value); + } + if let Some(mods) = env_modifications { + for &var in mods.keys() { + if !original_env.contains_key(var) { + std::env::remove_var(var); + } } } } - std::mem::drop(lock); - let tester = ProviderTester::new(provider, name.to_string()); match tester.run_test_suite().await { Ok(_) => { @@ -450,26 +442,20 @@ where #[tokio::test] async fn test_openai_provider() -> Result<()> { - test_provider( - "OpenAI", - &["OPENAI_API_KEY"], - None, - openai::OpenAiProvider::default, - ) - .await + test_provider("openai", OPEN_AI_DEFAULT_MODEL, &["OPENAI_API_KEY"], None).await } #[tokio::test] async fn test_azure_provider() -> Result<()> { test_provider( "Azure", + AZURE_DEFAULT_MODEL, &[ "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME", ], None, - azure::AzureProvider::default, ) .await } @@ -478,26 +464,23 @@ async fn test_azure_provider() -> Result<()> { async fn test_bedrock_provider_long_term_credentials() -> Result<()> { test_provider( "Bedrock", + BEDROCK_DEFAULT_MODEL, &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], None, - bedrock::BedrockProvider::default, ) .await } #[tokio::test] async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { - let env_mods = HashMap::from_iter([ - // Ensure to unset long-term credentials to use AWS Profile provider - ("AWS_ACCESS_KEY_ID", None), - ("AWS_SECRET_ACCESS_KEY", None), - ]); + let env_mods = + HashMap::from_iter([("AWS_ACCESS_KEY_ID", None), ("AWS_SECRET_ACCESS_KEY", None)]); test_provider( - "Bedrock AWS Profile Credentials", + "Bedrock", + BEDROCK_DEFAULT_MODEL, &["AWS_PROFILE"], Some(env_mods), - bedrock::BedrockProvider::default, ) .await } @@ -506,50 +489,30 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { async fn test_databricks_provider() -> Result<()> { test_provider( "Databricks", + DATABRICKS_DEFAULT_MODEL, &["DATABRICKS_HOST", "DATABRICKS_TOKEN"], None, - databricks::DatabricksProvider::default, - ) - .await -} - -#[tokio::test] -async fn test_databricks_provider_oauth() -> Result<()> { - let mut env_mods = HashMap::new(); - env_mods.insert("DATABRICKS_TOKEN", None); - - test_provider( - "Databricks OAuth", - &["DATABRICKS_HOST"], - Some(env_mods), - databricks::DatabricksProvider::default, ) .await } #[tokio::test] async fn test_ollama_provider() -> Result<()> { - test_provider( - "Ollama", - &["OLLAMA_HOST"], - None, - ollama::OllamaProvider::default, - ) - .await + test_provider("Ollama", OLLAMA_DEFAULT_MODEL, &["OLLAMA_HOST"], None).await } #[tokio::test] async fn test_groq_provider() -> Result<()> { - test_provider("Groq", &["GROQ_API_KEY"], None, groq::GroqProvider::default).await + test_provider("Groq", GROQ_DEFAULT_MODEL, &["GROQ_API_KEY"], None).await } #[tokio::test] async fn test_anthropic_provider() -> Result<()> { test_provider( "Anthropic", + ANTHROPIC_DEFAULT_MODEL, &["ANTHROPIC_API_KEY"], None, - anthropic::AnthropicProvider::default, ) .await } @@ -558,31 +521,25 @@ async fn test_anthropic_provider() -> Result<()> { async fn test_openrouter_provider() -> Result<()> { test_provider( "OpenRouter", + OPEN_AI_DEFAULT_MODEL, &["OPENROUTER_API_KEY"], None, - openrouter::OpenRouterProvider::default, ) .await } #[tokio::test] async fn test_google_provider() -> Result<()> { - test_provider( - "Google", - &["GOOGLE_API_KEY"], - None, - google::GoogleProvider::default, - ) - .await + test_provider("Google", GOOGLE_DEFAULT_MODEL, &["GOOGLE_API_KEY"], None).await } #[tokio::test] async fn test_snowflake_provider() -> Result<()> { test_provider( "Snowflake", + SNOWFLAKE_DEFAULT_MODEL, &["SNOWFLAKE_HOST", "SNOWFLAKE_TOKEN"], None, - snowflake::SnowflakeProvider::default, ) .await } @@ -591,9 +548,9 @@ async fn test_snowflake_provider() -> Result<()> { async fn test_sagemaker_tgi_provider() -> Result<()> { test_provider( "SageMakerTgi", + SAGEMAKER_TGI_DEFAULT_MODEL, &["SAGEMAKER_ENDPOINT_NAME"], None, - goose::providers::sagemaker_tgi::SageMakerTgiProvider::default, ) .await } @@ -611,21 +568,14 @@ async fn test_litellm_provider() -> Result<()> { ("LITELLM_API_KEY", Some("".to_string())), ]); - test_provider( - "LiteLLM", - &[], // No required environment variables - Some(env_mods), - litellm::LiteLLMProvider::default, - ) - .await + test_provider("LiteLLM", LITELLM_DEFAULT_MODEL, &[], Some(env_mods)).await } #[tokio::test] async fn test_xai_provider() -> Result<()> { - test_provider("Xai", &["XAI_API_KEY"], None, xai::XaiProvider::default).await + test_provider("Xai", XAI_DEFAULT_MODEL, &["XAI_API_KEY"], None).await } -// Print the final test report #[ctor::dtor] fn print_test_report() { TEST_REPORT.print_summary(); diff --git a/crates/goose/tests/tetrate_streaming.rs b/crates/goose/tests/tetrate_streaming.rs index 27d5ea1f5183..784adbba68f6 100644 --- a/crates/goose/tests/tetrate_streaming.rs +++ b/crates/goose/tests/tetrate_streaming.rs @@ -13,17 +13,17 @@ use serial_test::serial; mod tetrate_streaming_tests { use super::*; - fn create_test_provider() -> Result { + async fn create_test_provider() -> Result { // Create a test provider with the default model let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; - TetrateProvider::from_env(model_config) + TetrateProvider::from_env(model_config).await } #[tokio::test] #[serial] #[ignore] // Ignore by default, run with --ignored flag when API key is available async fn test_tetrate_streaming_basic() -> Result<()> { - let provider = create_test_provider()?; + let provider = create_test_provider().await?; let messages = vec![Message::user().with_text("Count from 1 to 5, one number at a time.")]; @@ -78,7 +78,7 @@ mod tetrate_streaming_tests { #[serial] #[ignore] async fn test_tetrate_streaming_with_tools() -> Result<()> { - let provider = create_test_provider()?; + let provider = create_test_provider().await?; // Define a simple tool let weather_tool = Tool::new( @@ -140,7 +140,7 @@ mod tetrate_streaming_tests { #[serial] #[ignore] async fn test_tetrate_streaming_empty_response() -> Result<()> { - let provider = create_test_provider()?; + let provider = create_test_provider().await?; // This might result in a very short or empty response let messages = vec![Message::user().with_text("")]; @@ -169,7 +169,7 @@ mod tetrate_streaming_tests { #[serial] #[ignore] async fn test_tetrate_streaming_long_response() -> Result<()> { - let provider = create_test_provider()?; + let provider = create_test_provider().await?; let messages = vec![Message::user().with_text( "Write a detailed 3-paragraph essay about the importance of streaming in modern APIs.", @@ -230,7 +230,7 @@ mod tetrate_streaming_tests { std::env::set_var("TETRATE_API_KEY", "invalid-key-for-testing"); let model_config = ModelConfig::new("claude-3-5-sonnet-latest")?; - let provider = TetrateProvider::from_env(model_config)?; + let provider = TetrateProvider::from_env(model_config).await?; let messages = vec![Message::user().with_text("Hello")]; @@ -251,7 +251,7 @@ mod tetrate_streaming_tests { #[serial] #[ignore] async fn test_tetrate_streaming_concurrent_streams() -> Result<()> { - let provider = create_test_provider()?; + let provider = create_test_provider().await?; // Create multiple concurrent streams let messages1 = vec![Message::user().with_text("Say 'Stream 1'")];