diff --git a/Cargo.lock b/Cargo.lock index 683515380869..d81ee21bf521 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3442,6 +3442,7 @@ dependencies = [ "criterion", "ctor", "dotenv", + "downcast-rs", "etcetera", "fs2", "futures", diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index d2c389807217..747862ded95e 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -56,12 +56,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { let agent: Agent = Agent::new(); let new_provider = create(&provider_name, model_config).unwrap(); let _ = agent.update_provider(new_provider).await; - - // Initialize router tool selector if vector strategy is enabled - if let Err(e) = agent.initialize_router_tool_selector().await { - output::render_error(&format!("Failed to initialize router tool selector: {}", e)); - process::exit(1); - } // Configure tool monitoring if max_tool_repetitions is set if let Some(max_repetitions) = session_config.max_tool_repetitions { diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 97142cb30ea1..dbdb23ff69d7 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -74,6 +74,7 @@ jsonwebtoken = "9.3.1" blake3 = "1.5" fs2 = "0.4.3" futures-util = "0.3.31" +downcast-rs = "1.2" # Vector database for tool selection lancedb = "0.13" diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 46760349500e..e04f04557902 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -74,27 +74,6 @@ impl Agent { router_tool_selector: Mutex::new(None), } } - - pub async fn initialize_router_tool_selector(&self) -> Result<()> { - let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") - .ok() - .and_then(|s| { - if s.eq_ignore_ascii_case("vector") { - Some(RouterToolSelectionStrategy::Vector) - } else { - None - } - }); - - if router_tool_selection_strategy.is_some() { - let selector = create_tool_selector(router_tool_selection_strategy) - .await - .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; - *self.router_tool_selector.lock().await = Some(selector); - } - - Ok(()) - } pub async fn configure_tool_monitor(&self, max_repetitions: Option) { let mut tool_monitor = self.tool_monitor.lock().await; @@ -292,11 +271,27 @@ impl Agent { ))] }) .map_err(|e| ToolError::ExecutionError(e.to_string())); - + // If vector tool selection is enabled, index the tools if result.is_ok() { - if let Err(e) = self.index_tools_if_vector_enabled().await { - tracing::error!("Failed to index tools after adding extension: {}", e); + if action == "disable" { + if let Err(e) = self + .index_tools_if_vector_enabled( + Some(extension_name.clone()), + Some("remove"), + false, + ) + .await + { + tracing::error!("Failed to remove tools from vector index: {}", e); + } + } else { + if let Err(e) = self + .index_tools_if_vector_enabled(Some(extension_name.clone()), Some("add"), false) + .await + { + tracing::error!("Failed to index tools: {}", e); + } } } @@ -336,12 +331,17 @@ impl Agent { extension_manager.add_extension(extension.clone()).await?; } }; - + // If vector tool selection is enabled, index the tools - if let Err(e) = self.index_tools_if_vector_enabled().await { - return Err(ExtensionError::SetupError( - format!("Failed to index tools for extension {}: {}", extension.name(), e), - )); + if let Err(e) = self + .index_tools_if_vector_enabled(Some(extension.name()), Some("add"), false) + .await + { + return Err(ExtensionError::SetupError(format!( + "Failed to index tools for extension {}: {}", + extension.name(), + e + ))); } Ok(()) @@ -398,6 +398,14 @@ impl Agent { .remove_extension(name) .await .expect("Failed to remove extension"); + + // If vector tool selection is enabled, remove tools from the index + if let Err(e) = self + .index_tools_if_vector_enabled(Some(name.to_string()), Some("remove"), false) + .await + { + tracing::error!("Failed to remove tools from vector index: {}", e); + } } pub async fn list_extensions(&self) -> Vec { @@ -621,7 +629,29 @@ impl Agent { /// Update the provider used by this agent pub async fn update_provider(&self, provider: Arc) -> Result<()> { - *self.provider.lock().await = Some(provider); + *self.provider.lock().await = Some(provider.clone()); + self.update_router_tool_selector(provider).await?; + Ok(()) + } + + async fn update_router_tool_selector(&self, provider: Arc) -> Result<()> { + let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .ok() + .and_then(|s| { + if s.eq_ignore_ascii_case("vector") { + Some(RouterToolSelectionStrategy::Vector) + } else { + None + } + }); + + if router_tool_selection_strategy.is_some() { + let selector = create_tool_selector(router_tool_selection_strategy, provider) + .await + .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; + *self.router_tool_selector.lock().await = Some(selector); + } + Ok(()) } @@ -688,47 +718,112 @@ impl Agent { } } - async fn index_tools_if_vector_enabled(&self) -> Result<()> { + async fn index_tools_if_vector_enabled( + &self, + extension_name: Option, + action: Option<&str>, + reindex_all: bool, + ) -> Result<()> { + // Only proceed if vector strategy is enabled + let is_vector_enabled = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .map(|s| s.eq_ignore_ascii_case("vector")) + .unwrap_or(false); + + if !is_vector_enabled { + return Ok(()); + } + let router_tool_selector = self.router_tool_selector.lock().await; if let Some(selector) = router_tool_selector.as_ref() { - // Get all tools from extension manager let extension_manager = self.extension_manager.lock().await; - let tools = extension_manager.get_prefixed_tools(None).await?; - - // Clear existing tools and re-index all - selector.clear_tools().await - .map_err(|e| anyhow!("Failed to clear tools: {}", e))?; - - // Index each tool - for tool in &tools { - let schema_str = serde_json::to_string_pretty(&tool.input_schema) - .unwrap_or_else(|_| "{}".to_string()); - - selector.index_tool( - tool.name.clone(), - tool.description.clone(), - schema_str, - ).await - .map_err(|e| anyhow!("Failed to index tool {}: {}", tool.name, e))?; + + if reindex_all { + // Clear and reindex everything + selector + .clear_tools() + .await + .map_err(|e| anyhow!("Failed to clear tools: {}", e))?; + + // Index all extension tools + let all_tools = extension_manager.get_prefixed_tools(None).await?; + for tool in &all_tools { + let schema_str = serde_json::to_string_pretty(&tool.input_schema) + .unwrap_or_else(|_| "{}".to_string()); + + selector + .index_tool(tool.name.clone(), tool.description.clone(), schema_str) + .await + .map_err(|e| anyhow!("Failed to index tool {}: {}", tool.name, e))?; + } + + // Index all frontend tools + let frontend_tools = self.frontend_tools.lock().await; + for frontend_tool in frontend_tools.values() { + let schema_str = serde_json::to_string_pretty(&frontend_tool.tool.input_schema) + .unwrap_or_else(|_| "{}".to_string()); + + selector + .index_tool( + frontend_tool.tool.name.clone(), + frontend_tool.tool.description.clone(), + schema_str, + ) + .await + .map_err(|e| { + anyhow!( + "Failed to index frontend tool {}: {}", + frontend_tool.tool.name, + e + ) + })?; + } + + tracing::info!("Reindexed all tools for vector search"); + return Ok(()); } - - // Also index frontend tools - let frontend_tools = self.frontend_tools.lock().await; - for frontend_tool in frontend_tools.values() { - let schema_str = serde_json::to_string_pretty(&frontend_tool.tool.input_schema) - .unwrap_or_else(|_| "{}".to_string()); - - selector.index_tool( - frontend_tool.tool.name.clone(), - frontend_tool.tool.description.clone(), - schema_str, - ).await - .map_err(|e| anyhow!("Failed to index frontend tool {}: {}", frontend_tool.tool.name, e))?; + + // Handle specific extension operations + if let (Some(ext_name), Some(act)) = (extension_name, action) { + match act { + "add" => { + // Get tools for specific extension + let tools = extension_manager + .get_prefixed_tools(Some(ext_name.clone())) + .await?; + for tool in &tools { + let schema_str = serde_json::to_string_pretty(&tool.input_schema) + .unwrap_or_else(|_| "{}".to_string()); + + selector + .index_tool(tool.name.clone(), tool.description.clone(), schema_str) + .await + .map_err(|e| { + anyhow!("Failed to index tool {}: {}", tool.name, e) + })?; + } + tracing::info!("Indexed {} tools for extension {}", tools.len(), ext_name); + } + "remove" => { + // Get tool names for the extension to remove them + let tools = extension_manager + .get_prefixed_tools(Some(ext_name.clone())) + .await?; + for tool in &tools { + selector.remove_tool(&tool.name).await.map_err(|e| { + anyhow!("Failed to remove tool {}: {}", tool.name, e) + })?; + } + tracing::info!("Removed {} tools for extension {}", tools.len(), ext_name); + } + _ => { + anyhow::bail!("Invalid action '{}' for tool indexing", act); + } + } + } else { + anyhow::bail!("Extension name and action required for tool indexing"); } - - tracing::info!("Indexed {} tools for vector search", tools.len() + frontend_tools.len()); } - + Ok(()) } diff --git a/crates/goose/src/agents/embeddings.rs b/crates/goose/src/agents/embeddings.rs index 5d11f909a959..3e315c5ed08a 100644 --- a/crates/goose/src/agents/embeddings.rs +++ b/crates/goose/src/agents/embeddings.rs @@ -1,49 +1,118 @@ +use crate::model::ModelConfig; +use crate::providers::base::Provider; +use crate::providers::databricks::DatabricksProvider; +use crate::providers::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use anyhow::{Context, Result}; use reqwest::Client; -use serde::{Deserialize, Serialize}; use std::env; - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct EmbeddingRequest { - input: Vec, - model: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct EmbeddingResponse { - data: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct EmbeddingData { - embedding: Vec, -} +use std::sync::Arc; pub struct EmbeddingProvider { + provider: Arc, client: Client, - api_key: String, + token: String, base_url: String, - model: String, + model: ModelConfig, } impl EmbeddingProvider { - pub fn new() -> Result { - // Try to get API key from environment - let api_key = env::var("OPENAI_API_KEY") - .or_else(|_| env::var("EMBEDDING_API_KEY")) - .context("No API key found for embeddings. Set OPENAI_API_KEY or EMBEDDING_API_KEY")?; + pub fn new(provider: Arc) -> Result { + // Get configuration from the provider + let model_config = provider.get_model_config(); + let config = crate::config::Config::global(); + + // Try to use provider's embedding capability if available + if let Some(embedding_provider) = provider + .as_ref() + .as_any() + .downcast_ref::() + { + eprintln!("Using provider's native embedding capability"); + // For Databricks, we need to use a specific embedding model + let embedding_model = env::var("EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + return Ok(Self { + provider, + client: Client::new(), + token: String::new(), // Not used when using provider's capability + base_url: String::new(), // Not used when using provider's capability + model: ModelConfig::new(embedding_model), + }); + } + + // Check if this is a Databricks provider using the provider's metadata + let is_databricks = provider.get_name() == "DatabricksProvider"; + eprintln!("Provider name: {}", provider.get_name()); + eprintln!("Is Databricks provider: {}", is_databricks); + + let (base_url, token) = if is_databricks { + let mut host: Result = + config.get_param("DATABRICKS_HOST"); + if host.is_err() { + host = config.get_secret("DATABRICKS_HOST"); + } + let host = host.context("No Databricks host found in config or secrets")?; + eprintln!("Databricks host: {}", host); + + // Check if this is an internal user + let is_internal = + host.as_str() == "https://block-lakehouse-production.cloud.databricks.com"; + eprintln!("Is internal user: {}", is_internal); + + // Get auth token + let token = if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") { + api_key + } else { + std::env::var("DATABRICKS_TOKEN").context( + "No API key found for embeddings. Please set DATABRICKS_TOKEN environment variable", + )? + }; + + if is_internal { + let model = env::var("EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + // Use internal Databricks endpoint + ( + format!("{}/serving-endpoints/{}/invocations", host, model), + token, + ) + } else { + // For external Databricks users, use OpenAI endpoint + let openai_key = std::env::var("OPENAI_API_KEY").context( + "No API key found for embeddings. Please set OPENAI_API_KEY environment variable", + )?; + + ( + env::var("EMBEDDING_BASE_URL") + .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()), + openai_key, + ) + } + } else { + // For other providers, use OpenAI endpoint + let token = std::env::var("OPENAI_API_KEY").context( + "No API key found for embeddings. Please set OPENAI_API_KEY environment variable", + )?; + + ( + env::var("EMBEDDING_BASE_URL") + .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()), + token, + ) + }; - let base_url = env::var("EMBEDDING_BASE_URL") - .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()); + let model = + env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-3-small".to_string()); - let model = env::var("EMBEDDING_MODEL") - .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + let log_msg = format!("Using base_url: {}, model: {}", base_url, model); + eprintln!("{}", log_msg); Ok(Self { + provider, client: Client::new(), - api_key, + token, base_url, - model, + model: ModelConfig::new(model), }) } @@ -52,15 +121,29 @@ impl EmbeddingProvider { return Ok(vec![]); } + // Try to use provider's embedding capability if available + if let Some(embedding_provider) = self + .provider + .as_ref() + .as_any() + .downcast_ref::() + { + return embedding_provider.create_embeddings(texts).await; + } + + // Fall back to default OpenAI-compatible implementation let request = EmbeddingRequest { input: texts, - model: self.model.clone(), + model: self.model.model_name.clone(), }; + // For OpenAI, we need to append /embeddings to the base URL + let url = format!("{}/embeddings", self.base_url.trim_end_matches('/')); + let response = self .client - .post(format!("{}/embeddings", self.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) + .post(&url) + .header("Authorization", format!("Bearer {}", self.token)) .header("Content-Type", "application/json") .json(&request) .send() @@ -104,14 +187,10 @@ impl MockEmbeddingProvider { pub async fn embed(&self, texts: Vec) -> Result>> { use rand::Rng; let mut rng = rand::thread_rng(); - + Ok(texts .into_iter() - .map(|_| { - (0..1536) - .map(|_| rng.gen_range(-1.0..1.0)) - .collect() - }) + .map(|_| (0..1536).map(|_| rng.gen_range(-1.0..1.0)).collect()) .collect()) } @@ -125,11 +204,22 @@ impl MockEmbeddingProvider { } // Factory function to create appropriate embedding provider -pub async fn create_embedding_provider() -> Box { - match EmbeddingProvider::new() { - Ok(provider) => Box::new(provider), +pub async fn create_embedding_provider( + provider: Arc, +) -> Box { + eprintln!("Attempting to create embedding provider..."); + + match EmbeddingProvider::new(provider) { + Ok(provider) => { + eprintln!("Successfully created real embedding provider"); + Box::new(provider) + } Err(e) => { - tracing::warn!("Failed to create embedding provider: {}. Using mock provider.", e); + eprintln!( + "Failed to create embedding provider: {}. Using mock provider.", + e + ); + eprintln!("Initializing mock embedding provider as fallback"); Box::new(MockEmbeddingProvider::new()) } } @@ -161,4 +251,149 @@ impl EmbeddingProviderTrait for MockEmbeddingProvider { async fn embed_single(&self, text: String) -> Result> { self.embed_single(text).await } -} \ No newline at end of file +} + +// Mock provider for testing +#[derive(Debug)] +#[allow(dead_code)] +struct MockProvider; + +#[async_trait::async_trait] +impl Provider for MockProvider { + fn metadata() -> crate::providers::base::ProviderMetadata { + crate::providers::base::ProviderMetadata::new( + "mock", + "Mock Provider", + "Mock provider for testing", + "mock-model", + vec![], + "", + vec![], + ) + } + + fn get_model_config(&self) -> crate::model::ModelConfig { + crate::model::ModelConfig::new("mock-model".to_string()) + } + + async fn complete( + &self, + _system: &str, + _messages: &[crate::message::Message], + _tools: &[mcp_core::tool::Tool], + ) -> Result< + ( + crate::message::Message, + crate::providers::base::ProviderUsage, + ), + crate::providers::errors::ProviderError, + > { + unimplemented!() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + #[tokio::test] + async fn test_mock_embedding_provider() { + let provider = MockEmbeddingProvider::new(); + + // Test single embedding + let text = "Test text for embedding".to_string(); + let embedding = provider.embed_single(text).await.unwrap(); + + // Check dimensions + assert_eq!(embedding.len(), 1536); + + // Check values are within expected range (-1.0 to 1.0) + for value in embedding { + assert!(value >= -1.0 && value <= 1.0); + } + + // Test batch embedding + let texts = vec![ + "First text".to_string(), + "Second text".to_string(), + "Third text".to_string(), + ]; + let embeddings = provider.embed(texts).await.unwrap(); + + // Check batch results + assert_eq!(embeddings.len(), 3); + for embedding in embeddings { + assert_eq!(embedding.len(), 1536); + for value in embedding { + assert!(value >= -1.0 && value <= 1.0); + } + } + } + + #[tokio::test] + async fn test_empty_input_mock_provider() { + let provider = MockEmbeddingProvider::new(); + let empty_texts: Vec = vec![]; + let result = provider.embed(empty_texts).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_embedding_provider_creation() { + // Test without API key + env::remove_var("OPENAI_API_KEY"); + assert!(EmbeddingProvider::new(Arc::new(MockProvider)).is_err()); + + // Test with API key + env::set_var("OPENAI_API_KEY", "test_key"); + let provider = EmbeddingProvider::new(Arc::new(MockProvider)).unwrap(); + assert_eq!(provider.token, "test_key"); + assert_eq!(provider.model, "mock-model"); + assert_eq!(provider.base_url, "https://api.openai.com/v1"); + + // Test with custom configuration + env::set_var("EMBEDDING_MODEL", "custom-model"); + env::set_var("EMBEDDING_BASE_URL", "https://custom.api.com"); + let provider = EmbeddingProvider::new(Arc::new(MockProvider)).unwrap(); + assert_eq!(provider.model, "custom-model"); + assert_eq!(provider.base_url, "https://custom.api.com"); + + // Cleanup + env::remove_var("OPENAI_API_KEY"); + env::remove_var("EMBEDDING_MODEL"); + env::remove_var("EMBEDDING_BASE_URL"); + } + + #[tokio::test] + async fn test_create_embedding_provider_fallback() { + // Remove API key to force fallback to mock provider + env::remove_var("OPENAI_API_KEY"); + + let provider = create_embedding_provider(Arc::new(MockProvider)).await; + + // Test that we get a working provider (mock in this case) + let text = "Test text".to_string(); + let embedding = provider.embed_single(text).await.unwrap(); + assert_eq!(embedding.len(), 1536); + } + + #[tokio::test] + async fn test_mock_embedding_consistency() { + let provider = MockEmbeddingProvider::new(); + + // Test that different texts get different embeddings + let text1 = "First text".to_string(); + let text2 = "Second text".to_string(); + + let embedding1 = provider.embed_single(text1.clone()).await.unwrap(); + let embedding2 = provider.embed_single(text2.clone()).await.unwrap(); + + // Verify embeddings are different (random values should make this extremely likely) + assert!(embedding1 != embedding2); + + // Verify same text gets different embeddings (mock doesn't cache) + let embedding1_repeat = provider.embed_single(text1).await.unwrap(); + assert!(embedding1 != embedding1_repeat); + } +} diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index b0332a027ba7..2c1dc4f23e33 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -1,13 +1,14 @@ -use mcp_core::{Content, ToolError}; use mcp_core::content::TextContent; +use mcp_core::{Content, ToolError}; use async_trait::async_trait; use serde_json::Value; use std::sync::Arc; use tokio::sync::RwLock; -use crate::agents::tool_vectordb::ToolVectorDB; use crate::agents::embeddings::{create_embedding_provider, EmbeddingProviderTrait}; +use crate::agents::tool_vectordb::ToolVectorDB; +use crate::providers::base::Provider; pub enum RouterToolSelectionStrategy { Vector, @@ -16,8 +17,14 @@ pub enum RouterToolSelectionStrategy { #[async_trait] pub trait RouterToolSelector: Send + Sync { async fn select_tools(&self, params: Value) -> Result, ToolError>; - async fn index_tool(&self, tool_name: String, description: String, schema: String) -> Result<(), ToolError>; + async fn index_tool( + &self, + tool_name: String, + description: String, + schema: String, + ) -> Result<(), ToolError>; async fn clear_tools(&self) -> Result<(), ToolError>; + async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError>; } pub struct VectorToolSelector { @@ -26,13 +33,13 @@ pub struct VectorToolSelector { } impl VectorToolSelector { - pub async fn new() -> Result { - let vector_db = ToolVectorDB::new() + pub async fn new(provider: Arc) -> Result { + let vector_db = ToolVectorDB::new(Some("tools".to_string())) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to create vector DB: {}", e)))?; - - let embedding_provider = create_embedding_provider().await; - + + let embedding_provider = create_embedding_provider(provider.clone()).await; + Ok(Self { vector_db: Arc::new(RwLock::new(vector_db)), embedding_provider: Arc::new(embedding_provider), @@ -47,25 +54,25 @@ impl RouterToolSelector for VectorToolSelector { .get("query") .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::InvalidParameters("Missing 'query' parameter".to_string()))?; - - let limit = params - .get("limit") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - + + let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; + // Generate embedding for the query - let query_embedding = self.embedding_provider + let query_embedding = self + .embedding_provider .embed_single(query.to_string()) .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to generate query embedding: {}", e)))?; - + .map_err(|e| { + ToolError::ExecutionError(format!("Failed to generate query embedding: {}", e)) + })?; + // Search for similar tools let vector_db = self.vector_db.read().await; let tools = vector_db .search_tools(query_embedding, limit) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?; - + // Convert tool records to Content let selected_tools: Vec = tools .into_iter() @@ -80,20 +87,28 @@ impl RouterToolSelector for VectorToolSelector { }) }) .collect(); - + Ok(selected_tools) } - - async fn index_tool(&self, tool_name: String, description: String, schema: String) -> Result<(), ToolError> { + + async fn index_tool( + &self, + tool_name: String, + description: String, + schema: String, + ) -> Result<(), ToolError> { // Create text to embed let text_to_embed = format!("{} {} {}", tool_name, description, schema); - + // Generate embedding - let embedding = self.embedding_provider + let embedding = self + .embedding_provider .embed_single(text_to_embed) .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to generate tool embedding: {}", e)))?; - + .map_err(|e| { + ToolError::ExecutionError(format!("Failed to generate tool embedding: {}", e)) + })?; + // Index the tool let vector_db = self.vector_db.read().await; let tool_record = crate::agents::tool_vectordb::ToolRecord { @@ -102,15 +117,15 @@ impl RouterToolSelector for VectorToolSelector { schema, vector: embedding, }; - + vector_db .index_tools(vec![tool_record]) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to index tool: {}", e)))?; - + Ok(()) } - + async fn clear_tools(&self) -> Result<(), ToolError> { let vector_db = self.vector_db.read().await; vector_db @@ -119,19 +134,28 @@ impl RouterToolSelector for VectorToolSelector { .map_err(|e| ToolError::ExecutionError(format!("Failed to clear tools: {}", e)))?; Ok(()) } + + async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError> { + let vector_db = self.vector_db.read().await; + vector_db.remove_tool(tool_name).await.map_err(|e| { + ToolError::ExecutionError(format!("Failed to remove tool {}: {}", tool_name, e)) + })?; + Ok(()) + } } // Helper function to create a boxed tool selector pub async fn create_tool_selector( strategy: Option, + provider: Arc, ) -> Result, ToolError> { match strategy { Some(RouterToolSelectionStrategy::Vector) => { - let selector = VectorToolSelector::new().await?; + let selector = VectorToolSelector::new(provider).await?; Ok(Box::new(selector)) } _ => { - let selector = VectorToolSelector::new().await?; + let selector = VectorToolSelector::new(provider).await?; Ok(Box::new(selector)) } } diff --git a/crates/goose/src/agents/router_tools.rs b/crates/goose/src/agents/router_tools.rs index fe7aee59e5bb..332e1dff5aef 100644 --- a/crates/goose/src/agents/router_tools.rs +++ b/crates/goose/src/agents/router_tools.rs @@ -22,7 +22,8 @@ pub fn vector_search_tool() -> Tool { "type": "object", "required": ["query"], "properties": { - "query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"} + "query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"}, + "k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 10)", "default": 10} } }), Some(ToolAnnotations { diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 98c561eb6526..866f581d858b 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -1,15 +1,15 @@ use anyhow::{Context, Result}; -use arrow::array::{StringArray, FixedSizeListBuilder}; +use arrow::array::{FixedSizeListBuilder, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; +use etcetera::BaseStrategy; +use futures::TryStreamExt; +use lancedb::connect; use lancedb::connection::Connection; use lancedb::query::{ExecutableQuery, QueryBase}; -use lancedb::connect; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use etcetera::BaseStrategy; -use futures::TryStreamExt; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolRecord { @@ -25,9 +25,9 @@ pub struct ToolVectorDB { } impl ToolVectorDB { - pub async fn new() -> Result { + pub async fn new(table_name: Option) -> Result { let db_path = Self::get_db_path()?; - + // Ensure the directory exists if let Some(parent) = db_path.parent() { tokio::fs::create_dir_all(parent) @@ -42,7 +42,7 @@ impl ToolVectorDB { let tool_db = Self { connection: Arc::new(RwLock::new(connection)), - table_name: "tools".to_string(), + table_name: table_name.unwrap_or_else(|| "tools".to_string()), }; // Initialize the table if it doesn't exist @@ -55,13 +55,13 @@ impl ToolVectorDB { let data_dir = etcetera::choose_base_strategy() .context("Failed to determine base strategy")? .data_dir(); - + Ok(data_dir.join("goose").join("tool_db")) } async fn init_table(&self) -> Result<()> { let connection = self.connection.read().await; - + // Check if table exists let table_names = connection .table_names() @@ -75,19 +75,24 @@ impl ToolVectorDB { Field::new("tool_name", DataType::Utf8, false), Field::new("description", DataType::Utf8, false), Field::new("schema", DataType::Utf8, false), - Field::new("vector", DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float32, true)), - 1536, // OpenAI embedding dimension - ), false), + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 1536, // OpenAI embedding dimension + ), + false, + ), ])); // Create empty table let tool_names = StringArray::from(vec![] as Vec<&str>); let descriptions = StringArray::from(vec![] as Vec<&str>); let schemas = StringArray::from(vec![] as Vec<&str>); - + // Create empty fixed size list array for vectors - let mut vectors_builder = FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536); + let mut vectors_builder = + FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536); let vectors = vectors_builder.finish(); let batch = arrow::record_batch::RecordBatch::try_new( @@ -104,13 +109,13 @@ impl ToolVectorDB { // LanceDB will create the table from the RecordBatch drop(connection); let connection = self.connection.write().await; - + // Use the RecordBatch directly let reader = arrow::record_batch::RecordBatchIterator::new( vec![Ok(batch)].into_iter(), schema.clone(), ); - + connection .create_table(&self.table_name, Box::new(reader)) .execute() @@ -123,7 +128,7 @@ impl ToolVectorDB { pub async fn clear_tools(&self) -> Result<()> { let connection = self.connection.write().await; - + // Drop the table if it exists let table_names = connection .table_names() @@ -139,10 +144,10 @@ impl ToolVectorDB { } drop(connection); - + // Reinitialize the table self.init_table().await?; - + Ok(()) } @@ -154,8 +159,9 @@ impl ToolVectorDB { let tool_names: Vec<&str> = tools.iter().map(|t| t.tool_name.as_str()).collect(); let descriptions: Vec<&str> = tools.iter().map(|t| t.description.as_str()).collect(); let schemas: Vec<&str> = tools.iter().map(|t| t.schema.as_str()).collect(); - - let vectors_data: Vec>>> = tools.iter() + + let vectors_data: Vec>>> = tools + .iter() .map(|t| Some(t.vector.iter().map(|&v| Some(v)).collect())) .collect(); @@ -163,17 +169,22 @@ impl ToolVectorDB { Field::new("tool_name", DataType::Utf8, false), Field::new("description", DataType::Utf8, false), Field::new("schema", DataType::Utf8, false), - Field::new("vector", DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float32, true)), - 1536, - ), false), + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 1536, + ), + false, + ), ])); let tool_names_array = StringArray::from(tool_names); let descriptions_array = StringArray::from(descriptions); let schemas_array = StringArray::from(schemas); // Build vectors array - let mut vectors_builder = FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536); + let mut vectors_builder = + FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536); for vector_opt in vectors_data { if let Some(vector) = vector_opt { let values = vectors_builder.values(); @@ -214,7 +225,7 @@ impl ToolVectorDB { vec![Ok(batch)].into_iter(), schema.clone(), ); - + table .add(Box::new(reader)) .execute() @@ -224,9 +235,13 @@ impl ToolVectorDB { Ok(()) } - pub async fn search_tools(&self, query_vector: Vec, limit: usize) -> Result> { + pub async fn search_tools( + &self, + query_vector: Vec, + limit: usize, + ) -> Result> { let connection = self.connection.read().await; - + let table = connection .open_table(&self.table_name) .execute() @@ -242,7 +257,7 @@ impl ToolVectorDB { .context("Failed to execute vector search")?; let batches: Vec<_> = results.try_collect().await?; - + let mut tools = Vec::new(); for batch in batches { let tool_names = batch @@ -275,9 +290,26 @@ impl ToolVectorDB { }); } } - Ok(tools) } + + pub async fn remove_tool(&self, tool_name: &str) -> Result<()> { + let connection = self.connection.read().await; + + let table = connection + .open_table(&self.table_name) + .execute() + .await + .context("Failed to open tools table")?; + + // Delete records matching the tool name + table + .delete(&format!("tool_name = '{}'", tool_name)) + .await + .context("Failed to delete tool")?; + + Ok(()) + } } #[cfg(test)] @@ -286,7 +318,108 @@ mod tests { #[tokio::test] async fn test_tool_vectordb_creation() { - let db = ToolVectorDB::new().await.unwrap(); - assert_eq!(db.table_name, "tools"); + let db = ToolVectorDB::new(Some("test_tools_vectordb_creation".to_string())) + .await + .unwrap(); + db.clear_tools().await.unwrap(); + assert_eq!(db.table_name, "test_tools_vectordb_creation"); + } + + #[tokio::test] + async fn test_tool_vectordb_operations() -> Result<()> { + // Create a new database instance with a unique table name + let db = ToolVectorDB::new(Some("test_tool_vectordb_operations".to_string())).await?; + + // Clear any existing tools + db.clear_tools().await?; + + // Create test tool records + let test_tools = vec![ + ToolRecord { + tool_name: "test_tool_1".to_string(), + description: "A test tool for reading files".to_string(), + schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# + .to_string(), + vector: vec![0.1; 1536], // Mock embedding vector + }, + ToolRecord { + tool_name: "test_tool_2".to_string(), + description: "A test tool for writing files".to_string(), + schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# + .to_string(), + vector: vec![0.2; 1536], // Different mock embedding vector + }, + ]; + + // Index the test tools + db.index_tools(test_tools).await?; + + // Search for tools using a query vector similar to test_tool_1 + let query_vector = vec![0.1; 1536]; + let results = db.search_tools(query_vector, 2).await?; + + // Verify results + assert_eq!(results.len(), 2, "Should find both tools"); + assert_eq!( + results[0].tool_name, "test_tool_1", + "First result should be test_tool_1" + ); + assert_eq!( + results[1].tool_name, "test_tool_2", + "Second result should be test_tool_2" + ); + + Ok(()) } -} \ No newline at end of file + + #[tokio::test] + async fn test_empty_db() -> Result<()> { + // Create a new database instance with a unique table name + let db = ToolVectorDB::new(Some("test_empty_db".to_string())).await?; + + // Clear any existing tools + db.clear_tools().await?; + + // Search in empty database + let query_vector = vec![0.1; 1536]; + let results = db.search_tools(query_vector, 2).await?; + + // Verify no results returned + assert_eq!(results.len(), 0, "Empty database should return no results"); + + Ok(()) + } + + #[tokio::test] + async fn test_tool_deletion() -> Result<()> { + // Create a new database instance with a unique table name + let db = ToolVectorDB::new(Some("test_tool_deletion".to_string())).await?; + + // Clear any existing tools + db.clear_tools().await?; + + // Create and index a test tool + let test_tool = ToolRecord { + tool_name: "test_tool_to_delete".to_string(), + description: "A test tool that will be deleted".to_string(), + schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(), + vector: vec![0.1; 1536], + }; + + db.index_tools(vec![test_tool]).await?; + + // Verify tool exists + let query_vector = vec![0.1; 1536]; + let results = db.search_tools(query_vector.clone(), 1).await?; + assert_eq!(results.len(), 1, "Tool should exist before deletion"); + + // Delete the tool + db.remove_tool("test_tool_to_delete").await?; + + // Verify tool is gone + let results = db.search_tools(query_vector.clone(), 1).await?; + assert_eq!(results.len(), 0, "Tool should be deleted"); + + Ok(()) + } +} diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index aa4bea4e12f5..22219db6f9d4 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -147,15 +147,25 @@ impl Usage { } use async_trait::async_trait; +use downcast_rs::{impl_downcast, Downcast}; /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] -pub trait Provider: Send + Sync { +pub trait Provider: Send + Sync + Downcast { /// Get the metadata for this provider type fn metadata() -> ProviderMetadata where Self: Sized; + /// Get the name of this provider + fn get_name(&self) -> String { + std::any::type_name::() + .split("::") + .last() + .unwrap_or("unknown") + .to_string() + } + /// Generate the next message using the configured model and other parameters /// /// # Arguments @@ -185,6 +195,8 @@ pub trait Provider: Send + Sync { } } +impl_downcast!(Provider); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 189d5b97ca05..ca636d094ed0 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -1,11 +1,5 @@ -use anyhow::Result; -use async_trait::async_trait; -use reqwest::{Client, StatusCode}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::time::Duration; - use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use super::errors::ProviderError; use super::formats::databricks::{create_request, get_usage, response_to_message}; use super::oauth; @@ -14,8 +8,16 @@ use crate::config::ConfigError; use crate::message::Message; use crate::model::ModelConfig; use mcp_core::tool::Tool; +use serde_json::json; use url::Url; +use anyhow::Result; +use async_trait::async_trait; +use reqwest::{Client, StatusCode}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::time::Duration; + const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; // "offline_access" scope is used to request an OAuth 2.0 Refresh Token @@ -166,7 +168,17 @@ impl DatabricksProvider { async fn post(&self, payload: Value) -> Result { let base_url = Url::parse(&self.host) .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let path = format!("serving-endpoints/{}/invocations", self.model.model_name); + + // Check if this is an embedding request by looking at the payload structure + let is_embedding = payload.get("input").is_some() && !payload.get("messages").is_some(); + let path = if is_embedding { + // For embeddings, use the embeddings endpoint + format!("serving-endpoints/{}/invocations", "text-embedding-3-small") + } else { + // For chat completions, use the model name in the path + format!("serving-endpoints/{}/invocations", self.model.model_name) + }; + let url = base_url.join(&path).map_err(|e| { ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) })?; @@ -176,6 +188,7 @@ impl DatabricksProvider { .client .post(url) .header("Authorization", auth_header) + .header("Content-Type", "application/json") .json(&payload) .send() .await?; @@ -184,7 +197,7 @@ impl DatabricksProvider { let payload: Option = response.json().await.ok(); match status { - StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), + StatusCode::OK => payload.ok_or_else(|| ProviderError::RequestFailed("Response body is not valid JSON".to_string())), StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ Status: {}. Response: {:?}", status, payload))) @@ -296,3 +309,39 @@ impl Provider for DatabricksProvider { Ok((message, ProviderUsage::new(model, usage))) } } + +#[async_trait] +impl EmbeddingCapable for DatabricksProvider { + async fn create_embeddings(&self, texts: Vec) -> Result>> { + if texts.is_empty() { + return Ok(vec![]); + } + + // Create request in Databricks format for embeddings + let request = json!({ + "input": texts, + "instruction": "Represent this sentence for searching relevant passages:" + }); + + let response = self.post(request).await?; + // eprintln!("Databricks embedding response: {}", serde_json::to_string_pretty(&response)?); + + // Extract embeddings from Databricks response format + let embeddings = response["data"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("Invalid response format: missing data array"))? + .iter() + .map(|item| { + item["embedding"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("Invalid embedding format"))? + .iter() + .map(|v| v.as_f64().map(|f| f as f32)) + .collect::>>() + .ok_or_else(|| anyhow::anyhow!("Invalid embedding values")) + }) + .collect::>>>()?; + + Ok(embeddings) + } +} diff --git a/crates/goose/src/providers/embedding.rs b/crates/goose/src/providers/embedding.rs new file mode 100644 index 000000000000..469d22aeb57e --- /dev/null +++ b/crates/goose/src/providers/embedding.rs @@ -0,0 +1,24 @@ +use anyhow::Result; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingRequest { + pub input: Vec, + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingResponse { + pub data: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingData { + pub embedding: Vec, +} + +#[async_trait] +pub trait EmbeddingCapable { + async fn create_embeddings(&self, texts: Vec) -> Result>>; +} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 14b810d880dc..c91e43c6267f 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -4,6 +4,7 @@ pub mod azureauth; pub mod base; pub mod bedrock; pub mod databricks; +pub mod embedding; pub mod errors; mod factory; pub mod formats;