Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::agents::router_tool_selector::{
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
};
use crate::agents::router_tools::ROUTER_VECTOR_SEARCH_TOOL_NAME;
use crate::agents::tool_vectordb::generate_table_id;
use crate::agents::types::SessionConfig;
use crate::agents::types::{FrontendTool, ToolResultReceiver};
use mcp_core::{
Expand Down Expand Up @@ -192,7 +193,9 @@ impl Agent {
"Frontend tool execution required".to_string(),
))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
eprintln!("[DEBUG] Received tool call: {:?}", tool_call);
let router_tool_selector = self.router_tool_selector.lock().await;
eprintln!("[DEBUG] Router tool selector: ");
if let Some(selector) = router_tool_selector.as_ref() {
selector.select_tools(tool_call.arguments.clone()).await
} else {
Expand Down Expand Up @@ -670,24 +673,49 @@ impl Agent {
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.unwrap_or_else(|_| "default".to_string());

eprintln!(
"[DEBUG] Router tool selection strategy from config: {}",
router_tool_selection_strategy
);

let strategy = match router_tool_selection_strategy.to_lowercase().as_str() {
"vector" => Some(RouterToolSelectionStrategy::Vector),
_ => None,
};

if let Some(strategy) = strategy {
let selector = create_tool_selector(Some(strategy), provider)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
eprintln!("[DEBUG] Parsed strategy: {:?}", strategy);

// Clear tools from the vector database
selector
.clear_tools()
if let Some(strategy) = strategy {
eprintln!("[DEBUG] Creating tool selector with vector strategy...");
let table_name = generate_table_id();
eprintln!("[DEBUG] Table name: {}", table_name);
let selector = create_tool_selector(Some(strategy), provider, table_name)
.await
.map_err(|e| anyhow!("Failed to clear tools: {}", e))?;
.map_err(|e| {
eprintln!("[DEBUG] Failed to create tool selector: {}", e);
anyhow!("Failed to create tool selector: {}", e)
})?;

eprintln!("[DEBUG] Clearing existing tools from vector database...");
// // Clear tools from the vector database
// selector
// .clear_tools()
// .await
// .map_err(|e| {
// eprintln!("[DEBUG] Failed to clear tools: {}", e);
// anyhow!("Failed to clear tools: {}", e)
// })?;

eprintln!("[DEBUG] Setting router tool selector...");
*self.router_tool_selector.lock().await = Some(selector);

eprintln!("[DEBUG] Indexing platform tools...");
self.index_platform_tools().await?;
eprintln!("[DEBUG] Router tool selector initialization complete");
} else {
eprintln!(
"[DEBUG] No vector strategy selected, skipping router tool selector initialization"
);
}

Ok(())
Expand Down
4 changes: 2 additions & 2 deletions crates/goose/src/agents/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,14 @@ mod tests {
env::set_var("OPENAI_API_KEY", "test_key");
let provider = EmbeddingProvider::new(Arc::new(MockProvider)).unwrap();
assert_eq!(provider.token, "test_key");
assert_eq!(provider.model, "mock-model");
assert_eq!(provider.model.model_name, "mock-model");
assert_eq!(provider.base_url, "https://api.openai.com/v1");

// Test with custom configuration
env::set_var("EMBEDDING_MODEL", "custom-model");
env::set_var("EMBEDDING_BASE_URL", "https://custom.api.com");
let provider = EmbeddingProvider::new(Arc::new(MockProvider)).unwrap();
assert_eq!(provider.model, "custom-model");
assert_eq!(provider.model.model_name, "custom-model");
assert_eq!(provider.base_url, "https://custom.api.com");

// Cleanup
Expand Down
38 changes: 30 additions & 8 deletions crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::agents::embeddings::{create_embedding_provider, EmbeddingProviderTrai
use crate::agents::tool_vectordb::ToolVectorDB;
use crate::providers::base::Provider;

#[derive(Debug, Clone)]
pub enum RouterToolSelectionStrategy {
Vector,
}
Expand All @@ -37,8 +38,8 @@ pub struct VectorToolSelector {
}

impl VectorToolSelector {
pub async fn new(provider: Arc<dyn Provider>) -> Result<Self, ToolError> {
let vector_db = ToolVectorDB::new(Some("tools".to_string()))
pub async fn new(provider: Arc<dyn Provider>, table_name: String) -> Result<Self, ToolError> {
let vector_db = ToolVectorDB::new(Some(table_name))
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to create vector DB: {}", e)))?;

Expand All @@ -55,28 +56,44 @@ impl VectorToolSelector {
#[async_trait]
impl RouterToolSelector for VectorToolSelector {
async fn select_tools(&self, params: Value) -> Result<Vec<Content>, ToolError> {
eprintln!("[DEBUG] Received params: {:?}", params);

let query = params
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidParameters("Missing 'query' parameter".to_string()))?;

let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(20) as usize;
eprintln!("[DEBUG] Extracted query: {}", query);

let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
eprintln!("[DEBUG] Using k value: {}", k);

// Generate embedding for the query
eprintln!("[DEBUG] Generating embedding for query...");
let query_embedding = self
.embedding_provider
.embed_single(query.to_string())
.await
.map_err(|e| {
eprintln!("[DEBUG] Embedding generation failed: {}", e);
ToolError::ExecutionError(format!("Failed to generate query embedding: {}", e))
})?;
eprintln!("[DEBUG] Successfully generated embedding");

// Search for similar tools
eprintln!("[DEBUG] Starting vector search...");
let vector_db = self.vector_db.read().await;
let tools = vector_db
.search_tools(query_embedding, limit)
.search_tools(query_embedding, k)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?;
.map_err(|e| {
eprintln!("[DEBUG] Vector search failed: {}", e);
ToolError::ExecutionError(format!("Failed to search tools: {}", e))
})?;
eprintln!(
"[DEBUG] Vector search completed, found {} tools",
tools.len()
);

// Convert tool records to Content
let selected_tools: Vec<Content> = tools
Expand All @@ -93,6 +110,10 @@ impl RouterToolSelector for VectorToolSelector {
})
.collect();

eprintln!(
"[DEBUG] Successfully converted {} tools to Content",
selected_tools.len()
);
Ok(selected_tools)
}

Expand Down Expand Up @@ -132,7 +153,7 @@ impl RouterToolSelector for VectorToolSelector {
}

async fn clear_tools(&self) -> Result<(), ToolError> {
let vector_db = self.vector_db.read().await;
let vector_db = self.vector_db.write().await;
vector_db
.clear_tools()
.await
Expand Down Expand Up @@ -167,14 +188,15 @@ impl RouterToolSelector for VectorToolSelector {
pub async fn create_tool_selector(
strategy: Option<RouterToolSelectionStrategy>,
provider: Arc<dyn Provider>,
table_name: String,
) -> Result<Box<dyn RouterToolSelector>, ToolError> {
match strategy {
Some(RouterToolSelectionStrategy::Vector) => {
let selector = VectorToolSelector::new(provider).await?;
let selector = VectorToolSelector::new(provider, table_name).await?;
Ok(Box::new(selector))
}
None => {
let selector = VectorToolSelector::new(provider).await?;
let selector = VectorToolSelector::new(provider, table_name).await?;
Ok(Box::new(selector))
}
}
Expand Down
48 changes: 25 additions & 23 deletions crates/goose/src/agents/tool_vectordb.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anyhow::{Context, Result};
use arrow::array::{FixedSizeListBuilder, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use etcetera::BaseStrategy;
use chrono::Local;
use etcetera::base_strategy::{BaseStrategy, Xdg};
use futures::TryStreamExt;
use lancedb::connect;
use lancedb::connection::Connection;
Expand Down Expand Up @@ -52,7 +53,7 @@ impl ToolVectorDB {
}

fn get_db_path() -> Result<PathBuf> {
let data_dir = etcetera::choose_base_strategy()
let data_dir = Xdg::new()
.context("Failed to determine base strategy")?
.data_dir();

Expand Down Expand Up @@ -120,7 +121,9 @@ impl ToolVectorDB {
.create_table(&self.table_name, Box::new(reader))
.execute()
.await
.context("Failed to create tools table")?;
.map_err(|e| {
anyhow::anyhow!("Failed to create tools table '{}': {}", self.table_name, e)
})?;
}

Ok(())
Expand All @@ -129,23 +132,23 @@ impl ToolVectorDB {
pub async fn clear_tools(&self) -> Result<()> {
let connection = self.connection.write().await;

// Drop the table if it exists
let table_names = connection
.table_names()
.execute()
.await
.context("Failed to list tables")?;

if table_names.contains(&self.table_name) {
connection
.drop_table(&self.table_name)
.await
.context("Failed to drop tools table")?;
// Try to open the table first
match connection.open_table(&self.table_name).execute().await {
Ok(table) => {
// Delete all records instead of dropping the table
table
.delete("1=1") // This will match all records
.await
.context("Failed to delete all records")?;
}
Err(_) => {
// If table doesn't exist, that's fine - we'll create it
}
}

drop(connection);

// Reinitialize the table
// Ensure table exists with correct schema
self.init_table().await?;

Ok(())
Expand Down Expand Up @@ -235,11 +238,7 @@ impl ToolVectorDB {
Ok(())
}

pub async fn search_tools(
&self,
query_vector: Vec<f32>,
limit: usize,
) -> Result<Vec<ToolRecord>> {
pub async fn search_tools(&self, query_vector: Vec<f32>, k: usize) -> Result<Vec<ToolRecord>> {
let connection = self.connection.read().await;

let table = connection
Expand All @@ -251,7 +250,7 @@ impl ToolVectorDB {
let results = table
.vector_search(query_vector)
.context("Failed to create vector search")?
.limit(limit)
.limit(k)
.execute()
.await
.context("Failed to execute vector search")?;
Expand Down Expand Up @@ -292,7 +291,6 @@ impl ToolVectorDB {
for i in 0..batch.num_rows() {
let tool_name = tool_names.value(i).to_string();
let distance = distances.value(i);
eprintln!("Tool: {}, Distance Score: {}", tool_name, distance);

tools.push(ToolRecord {
tool_name,
Expand Down Expand Up @@ -324,6 +322,10 @@ impl ToolVectorDB {
}
}

pub fn generate_table_id() -> String {
Local::now().format("%Y%m%d_%H%M%S").to_string()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::formats::databricks::{create_request, get_usage, response_to_message};
use super::oauth;
Expand Down
3 changes: 3 additions & 0 deletions ui/desktop/src/components/settings_v2/SettingsView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { View, ViewOptions } from '../../App';
import ExtensionsSection from './extensions/ExtensionsSection';
import ModelsSection from './models/ModelsSection';
import { ModeSection } from './mode/ModeSection';
import { ToolSelectionStrategySection } from './tool_selection_strategy/ToolSelectionStrategySection';
import SessionSharingSection from './sessions/SessionSharingSection';
import { ResponseStylesSection } from './response_styles/ResponseStylesSection';
import { ExtensionConfig } from '../../api';
Expand Down Expand Up @@ -50,6 +51,8 @@ export default function SettingsView({
<SessionSharingSection />
{/* Response Styles */}
<ResponseStylesSection />
{/* Tool Selection Strategy */}
<ToolSelectionStrategySection setView={setView} />
</div>
</div>
</div>
Expand Down
Loading