diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index fbd7d5784b70..ef6176c789b0 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -267,210 +267,3 @@ pub async fn check_tool_permissions( extension_request_ids, ) } - -#[cfg(test)] -mod tests { - use super::*; - use crate::conversation::message::{Message, MessageContent, ToolRequest}; - use crate::mcp_utils::ToolResult; - use crate::model::ModelConfig; - use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; - use crate::providers::errors::ProviderError; - use chrono::Utc; - use rmcp::model::{CallToolRequestParam, Role, Tool}; - use tempfile::NamedTempFile; - - #[derive(Clone)] - struct MockProvider { - model_config: ModelConfig, - } - - #[async_trait::async_trait] - impl Provider for MockProvider { - fn metadata() -> ProviderMetadata { - ProviderMetadata::empty() - } - - fn get_model_config(&self) -> ModelConfig { - self.model_config.clone() - } - - async fn complete_with_model( - &self, - _model_config: &ModelConfig, - _system: &str, - _messages: &[Message], - _tools: &[Tool], - ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { - Ok(( - Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::ToolRequest(ToolRequest { - id: "mock_tool_request".to_string(), - tool_call: ToolResult::Ok(CallToolRequestParam { - name: "platform__tool_by_tool_permission".into(), - arguments: Some(object!({ - "read_only_tools": ["file_reader", "data_fetcher"] - })), - }), - })], - ), - ProviderUsage::new("mock".to_string(), Usage::default()), - )) - } - } - - fn create_mock_provider() -> Arc { - let config = ModelConfig::new_or_fail("test-model"); - let mock_model_config = config.with_context_limit(200_000.into()); - Arc::new(MockProvider { - model_config: mock_model_config, - }) - } - - #[tokio::test] - async fn test_create_read_only_tool() { - let tool = create_read_only_tool(); - assert_eq!(tool.name, "platform__tool_by_tool_permission"); - assert!(tool - .description - .as_ref() - .is_some_and(|desc| desc.contains("read-only operation"))); - } - - #[test] - fn test_create_check_messages() { - let tool_request = ToolRequest { - id: "tool_1".to_string(), - tool_call: Ok(CallToolRequestParam { - name: "file_reader".into(), - arguments: Some(object!({"path": "/path/to/file"})), - }), - }; - - let messages = create_check_messages(vec![&tool_request]); - assert_eq!(messages.len(), 1); - let content = &messages.first().unwrap().content[0]; - if let MessageContent::Text(text_content) = content { - assert!(text_content - .text - .contains("Analyze the tool requests and list the tools")); - assert!(text_content.text.contains("file_reader")); - } else { - panic!("Expected text content"); - } - } - - #[test] - fn test_extract_read_only_tools() { - let message = Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::ToolRequest(ToolRequest { - id: "tool_2".to_string(), - tool_call: Ok(CallToolRequestParam { - name: "platform__tool_by_tool_permission".into(), - arguments: Some(object!({ - "read_only_tools": ["file_reader", "data_fetcher"] - })), - }), - })], - ); - - let result = extract_read_only_tools(&message); - assert!(result.is_some()); - let tools = result.unwrap(); - assert_eq!(tools, vec!["file_reader", "data_fetcher"]); - } - - #[tokio::test] - async fn test_detect_read_only_tools() { - let provider = create_mock_provider(); - let tool_request = ToolRequest { - id: "tool_1".to_string(), - tool_call: Ok(CallToolRequestParam { - name: "file_reader".into(), - arguments: Some(object!({"path": "/path/to/file"})), - }), - }; - - let result = detect_read_only_tools(provider, vec![&tool_request]).await; - assert_eq!(result, vec!["file_reader", "data_fetcher"]); - } - - #[tokio::test] - async fn test_detect_read_only_tools_empty_requests() { - let provider = create_mock_provider(); - let result = detect_read_only_tools(provider, vec![]).await; - assert!(result.is_empty()); - } - - #[tokio::test] - async fn test_check_tool_permissions_smart_approve() { - // Setup mocks - let temp_file = NamedTempFile::new().unwrap(); - let temp_path = temp_file.path(); - let mut permission_manager = PermissionManager::new(temp_path); - let provider = create_mock_provider(); - - let tools_with_readonly_annotation: HashSet = - vec!["file_reader".to_string()].into_iter().collect(); - let tools_without_annotation: HashSet = - vec!["data_fetcher".to_string()].into_iter().collect(); - - permission_manager.update_user_permission("file_reader", PermissionLevel::AlwaysAllow); - permission_manager - .update_smart_approve_permission("data_fetcher", PermissionLevel::AskBefore); - - let tool_request_1 = ToolRequest { - id: "tool_1".to_string(), - tool_call: Ok(CallToolRequestParam { - name: "file_reader".into(), - arguments: Some(object!({"path": "/path/to/file"})), - }), - }; - - let tool_request_2 = ToolRequest { - id: "tool_2".to_string(), - tool_call: Ok(CallToolRequestParam { - name: "data_fetcher".into(), - arguments: Some(object!({"url": "http://example.com"})), - }), - }; - - let enable_extension = ToolRequest { - id: "tool_3".to_string(), - tool_call: Ok(CallToolRequestParam { - name: MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE.into(), - arguments: Some(object!({"action": "enable", "extension_name": "data_fetcher"})), - }), - }; - - let candidate_requests: Vec = - vec![tool_request_1, tool_request_2, enable_extension]; - - // Call the function under test - let (result, enable_extension_request_ids) = check_tool_permissions( - &candidate_requests, - "smart_approve", - tools_with_readonly_annotation, - tools_without_annotation, - &mut permission_manager, - provider, - ) - .await; - - // Validate the result - assert_eq!(result.approved.len(), 1); // file_reader should be approved - assert_eq!(result.needs_approval.len(), 2); // data_fetcher should need approval - assert_eq!(result.denied.len(), 0); // No tool should be denied in this test - assert_eq!(enable_extension_request_ids.len(), 1); - - // Ensure the right tools are in the approved and needs_approval lists - assert!(result.approved.iter().any(|req| req.id == "tool_1")); - assert!(result.needs_approval.iter().any(|req| req.id == "tool_2")); - assert!(result.needs_approval.iter().any(|req| req.id == "tool_3")); - assert!(enable_extension_request_ids.iter().any(|id| id == "tool_3")); - } -} diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index dd2a0f0ae565..3e585e3f0a85 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -1,5 +1,6 @@ use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; +use crate::providers::formats::google as gemini_schema; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, safely_parse_json, sanitize_function_name, ImageFormat, @@ -276,9 +277,7 @@ pub fn format_tools(tools: &[Tool], model_name: &str) -> anyhow::Result Vec { .collect() } -/// Get the accepted keys for a given parent key in the JSON schema. -fn get_accepted_keys(parent_key: Option<&str>) -> Vec<&str> { +pub fn get_accepted_keys(parent_key: Option<&str>) -> Vec<&str> { match parent_key { Some("properties") => vec![ "anyOf", @@ -178,7 +177,7 @@ fn get_accepted_keys(parent_key: Option<&str>) -> Vec<&str> { /// Process a JSON map to filter out unsupported attributes, mirroring the logic /// from the official Google Gemini CLI. /// See: https://github.com/google-gemini/gemini-cli/blob/8a6509ffeba271a8e7ccb83066a9a31a5d72a647/packages/core/src/tools/tool-registry.ts#L356 -fn process_map(map: &Map, parent_key: Option<&str>) -> Value { +pub fn process_map(map: &Map, parent_key: Option<&str>) -> Value { let accepted_keys = get_accepted_keys(parent_key); let filtered_map: Map = map .iter() diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 6a17a3d7df61..8fe23f08eeb6 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -92,29 +92,24 @@ pub fn map_http_error_to_provider_error( ProviderError::ContextLengthExceeded(payload_str) } StatusCode::BAD_REQUEST => { - let mut error_msg = "Unknown error".to_string(); + let base_msg = format!("Request failed with status: {}", status); if let Some(payload) = &payload { let payload_str = payload.to_string(); if check_context_length_exceeded(&payload_str) { ProviderError::ContextLengthExceeded(payload_str) } else { - if let Some(error) = payload.get("error") { - error_msg = error - .get("message") + ProviderError::RequestFailed( + payload + .get("error") + .and_then(|e| e.get("message")) + .or_else(|| payload.get("message")) .and_then(|m| m.as_str()) - .unwrap_or("Unknown error") - .to_string(); - } - ProviderError::RequestFailed(format!( - "Request failed with status: {}. Message: {}", - status, error_msg - )) + .map(|msg| format!("{}. Message: {}", base_msg, msg)) + .unwrap_or(base_msg), + ) } } else { - ProviderError::RequestFailed(format!( - "Request failed with status: {}. Message: {}", - status, error_msg - )) + ProviderError::RequestFailed(base_msg) } } StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded { @@ -1101,7 +1096,7 @@ mod tests { StatusCode::BAD_REQUEST, None, ProviderError::RequestFailed( - "Request failed with status: 400 Bad Request. Message: Unknown error".to_string(), + "Request failed with status: 400 Bad Request".to_string(), ), ), // TOO_MANY_REQUESTS