Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 0 additions & 207 deletions crates/goose/src/permission/permission_judge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,210 +267,3 @@ pub async fn check_tool_permissions(
extension_request_ids,
)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::conversation::message::{Message, MessageContent, ToolRequest};
use crate::mcp_utils::ToolResult;
use crate::model::ModelConfig;
use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
use crate::providers::errors::ProviderError;
use chrono::Utc;
use rmcp::model::{CallToolRequestParam, Role, Tool};
use tempfile::NamedTempFile;

#[derive(Clone)]
struct MockProvider {
model_config: ModelConfig,
}

#[async_trait::async_trait]
impl Provider for MockProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::empty()
}

fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}

async fn complete_with_model(
&self,
_model_config: &ModelConfig,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::ToolRequest(ToolRequest {
id: "mock_tool_request".to_string(),
tool_call: ToolResult::Ok(CallToolRequestParam {
name: "platform__tool_by_tool_permission".into(),
arguments: Some(object!({
"read_only_tools": ["file_reader", "data_fetcher"]
})),
}),
})],
),
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}
}

fn create_mock_provider() -> Arc<dyn Provider> {
let config = ModelConfig::new_or_fail("test-model");
let mock_model_config = config.with_context_limit(200_000.into());
Arc::new(MockProvider {
model_config: mock_model_config,
})
}

#[tokio::test]
async fn test_create_read_only_tool() {
let tool = create_read_only_tool();
assert_eq!(tool.name, "platform__tool_by_tool_permission");
assert!(tool
.description
.as_ref()
.is_some_and(|desc| desc.contains("read-only operation")));
}

#[test]
fn test_create_check_messages() {
let tool_request = ToolRequest {
id: "tool_1".to_string(),
tool_call: Ok(CallToolRequestParam {
name: "file_reader".into(),
arguments: Some(object!({"path": "/path/to/file"})),
}),
};

let messages = create_check_messages(vec![&tool_request]);
assert_eq!(messages.len(), 1);
let content = &messages.first().unwrap().content[0];
if let MessageContent::Text(text_content) = content {
assert!(text_content
.text
.contains("Analyze the tool requests and list the tools"));
assert!(text_content.text.contains("file_reader"));
} else {
panic!("Expected text content");
}
}

#[test]
fn test_extract_read_only_tools() {
let message = Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::ToolRequest(ToolRequest {
id: "tool_2".to_string(),
tool_call: Ok(CallToolRequestParam {
name: "platform__tool_by_tool_permission".into(),
arguments: Some(object!({
"read_only_tools": ["file_reader", "data_fetcher"]
})),
}),
})],
);

let result = extract_read_only_tools(&message);
assert!(result.is_some());
let tools = result.unwrap();
assert_eq!(tools, vec!["file_reader", "data_fetcher"]);
}

#[tokio::test]
async fn test_detect_read_only_tools() {
let provider = create_mock_provider();
let tool_request = ToolRequest {
id: "tool_1".to_string(),
tool_call: Ok(CallToolRequestParam {
name: "file_reader".into(),
arguments: Some(object!({"path": "/path/to/file"})),
}),
};

let result = detect_read_only_tools(provider, vec![&tool_request]).await;
assert_eq!(result, vec!["file_reader", "data_fetcher"]);
}

#[tokio::test]
async fn test_detect_read_only_tools_empty_requests() {
let provider = create_mock_provider();
let result = detect_read_only_tools(provider, vec![]).await;
assert!(result.is_empty());
}

#[tokio::test]
async fn test_check_tool_permissions_smart_approve() {
// Setup mocks
let temp_file = NamedTempFile::new().unwrap();
let temp_path = temp_file.path();
let mut permission_manager = PermissionManager::new(temp_path);
let provider = create_mock_provider();

let tools_with_readonly_annotation: HashSet<String> =
vec!["file_reader".to_string()].into_iter().collect();
let tools_without_annotation: HashSet<String> =
vec!["data_fetcher".to_string()].into_iter().collect();

permission_manager.update_user_permission("file_reader", PermissionLevel::AlwaysAllow);
permission_manager
.update_smart_approve_permission("data_fetcher", PermissionLevel::AskBefore);

let tool_request_1 = ToolRequest {
id: "tool_1".to_string(),
tool_call: Ok(CallToolRequestParam {
name: "file_reader".into(),
arguments: Some(object!({"path": "/path/to/file"})),
}),
};

let tool_request_2 = ToolRequest {
id: "tool_2".to_string(),
tool_call: Ok(CallToolRequestParam {
name: "data_fetcher".into(),
arguments: Some(object!({"url": "http://example.com"})),
}),
};

let enable_extension = ToolRequest {
id: "tool_3".to_string(),
tool_call: Ok(CallToolRequestParam {
name: MANAGE_EXTENSIONS_TOOL_NAME_COMPLETE.into(),
arguments: Some(object!({"action": "enable", "extension_name": "data_fetcher"})),
}),
};

let candidate_requests: Vec<ToolRequest> =
vec![tool_request_1, tool_request_2, enable_extension];

// Call the function under test
let (result, enable_extension_request_ids) = check_tool_permissions(
&candidate_requests,
"smart_approve",
tools_with_readonly_annotation,
tools_without_annotation,
&mut permission_manager,
provider,
)
.await;

// Validate the result
assert_eq!(result.approved.len(), 1); // file_reader should be approved
assert_eq!(result.needs_approval.len(), 2); // data_fetcher should need approval
assert_eq!(result.denied.len(), 0); // No tool should be denied in this test
assert_eq!(enable_extension_request_ids.len(), 1);

// Ensure the right tools are in the approved and needs_approval lists
assert!(result.approved.iter().any(|req| req.id == "tool_1"));
assert!(result.needs_approval.iter().any(|req| req.id == "tool_2"));
assert!(result.needs_approval.iter().any(|req| req.id == "tool_3"));
assert!(enable_extension_request_ids.iter().any(|id| id == "tool_3"));
}
}
5 changes: 2 additions & 3 deletions crates/goose/src/providers/formats/databricks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::conversation::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::formats::google as gemini_schema;
use crate::providers::utils::{
convert_image, detect_image_path, is_valid_function_name, load_image_file, safely_parse_json,
sanitize_function_name, ImageFormat,
Expand Down Expand Up @@ -276,9 +277,7 @@ pub fn format_tools(tools: &[Tool], model_name: &str) -> anyhow::Result<Vec<Valu
}

let parameters = if is_gemini {
let mut cleaned_schema = tool.input_schema.as_ref().clone();
cleaned_schema.remove("$schema");
json!(cleaned_schema)
gemini_schema::process_map(tool.input_schema.as_ref(), None)
} else {
json!(tool.input_schema)
};
Expand Down
5 changes: 2 additions & 3 deletions crates/goose/src/providers/formats/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ pub fn format_tools(tools: &[Tool]) -> Vec<Value> {
.collect()
}

/// Get the accepted keys for a given parent key in the JSON schema.
fn get_accepted_keys(parent_key: Option<&str>) -> Vec<&str> {
pub fn get_accepted_keys(parent_key: Option<&str>) -> Vec<&str> {
match parent_key {
Some("properties") => vec![
"anyOf",
Expand All @@ -178,7 +177,7 @@ fn get_accepted_keys(parent_key: Option<&str>) -> Vec<&str> {
/// Process a JSON map to filter out unsupported attributes, mirroring the logic
/// from the official Google Gemini CLI.
/// See: https://github.com/google-gemini/gemini-cli/blob/8a6509ffeba271a8e7ccb83066a9a31a5d72a647/packages/core/src/tools/tool-registry.ts#L356
fn process_map(map: &Map<String, Value>, parent_key: Option<&str>) -> Value {
pub fn process_map(map: &Map<String, Value>, parent_key: Option<&str>) -> Value {
let accepted_keys = get_accepted_keys(parent_key);
let filtered_map: Map<String, Value> = map
.iter()
Expand Down
27 changes: 11 additions & 16 deletions crates/goose/src/providers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,29 +92,24 @@ pub fn map_http_error_to_provider_error(
ProviderError::ContextLengthExceeded(payload_str)
}
StatusCode::BAD_REQUEST => {
let mut error_msg = "Unknown error".to_string();
let base_msg = format!("Request failed with status: {}", status);
if let Some(payload) = &payload {
let payload_str = payload.to_string();
if check_context_length_exceeded(&payload_str) {
ProviderError::ContextLengthExceeded(payload_str)
} else {
if let Some(error) = payload.get("error") {
error_msg = error
.get("message")
ProviderError::RequestFailed(
payload
.get("error")
.and_then(|e| e.get("message"))
.or_else(|| payload.get("message"))
.and_then(|m| m.as_str())
.unwrap_or("Unknown error")
.to_string();
}
ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
))
.map(|msg| format!("{}. Message: {}", base_msg, msg))
.unwrap_or(base_msg),
)
}
} else {
ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
))
ProviderError::RequestFailed(base_msg)
}
}
StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded {
Expand Down Expand Up @@ -1101,7 +1096,7 @@ mod tests {
StatusCode::BAD_REQUEST,
None,
ProviderError::RequestFailed(
"Request failed with status: 400 Bad Request. Message: Unknown error".to_string(),
"Request failed with status: 400 Bad Request".to_string(),
),
),
// TOO_MANY_REQUESTS
Expand Down
Loading