Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion crates/goose-cli/src/session/export.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use goose::conversation::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::conversation::message::{
ActionRequiredData, Message, MessageContent, ToolRequest, ToolResponse,
};
use goose::utils::safe_truncate;
use rmcp::model::{RawContent, ResourceContents, Role};
use serde_json::Value;
Expand Down Expand Up @@ -340,6 +342,14 @@ pub fn message_to_markdown(message: &Message, export_all_content: bool) -> Strin
let mut md = String::new();
for content in &message.content {
match content {
MessageContent::ActionRequired(action) => match &action.data {
ActionRequiredData::ToolConfirmation { tool_name, .. } => {
md.push_str(&format!(
"**Action Required** (tool_confirmation): {}\n\n",
tool_name
));
}
},
MessageContent::Text(text) => {
md.push_str(&text.text);
md.push_str("\n\n");
Expand Down
35 changes: 23 additions & 12 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use rmcp::model::ServerNotification;
use rmcp::model::{ErrorCode, ErrorData};

use goose::config::paths::Paths;
use goose::conversation::message::{Message, MessageContent};
use goose::conversation::message::{ActionRequiredData, Message, MessageContent};
use rand::{distributions::Alphanumeric, Rng};
use rustyline::EditMode;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -852,20 +852,32 @@ impl CliSession {
result = stream.next() => {
match result {
Some(Ok(AgentEvent::Message(message))) => {
// If it's a confirmation request, get approval but otherwise do not render/persist
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
let tool_call_confirmation = message.content.iter().find_map(|content| {
if let MessageContent::ActionRequired(action) = content {
#[allow(irrefutable_let_patterns)] // this is a one variant enum right now but it will have more
if let ActionRequiredData::ToolConfirmation { id, tool_name, arguments, prompt } = &action.data {
Some((id.clone(), tool_name.clone(), arguments.clone(), prompt.clone()))
} else {
None
}
} else {
None
}
});

if let Some((id, _tool_name, _arguments, security_prompt)) = tool_call_confirmation {
output::hide_thinking();

// Format the confirmation prompt - use security message if present, otherwise use generic message
let prompt = if let Some(security_message) = &confirmation.prompt {
let prompt = if let Some(security_message) = &security_prompt {
println!("\n{}", security_message);
"Do you allow this tool call?".to_string()
} else {
"Goose would like to call the above tool, do you allow?".to_string()
};

// Get confirmation from user
let permission_result = if confirmation.prompt.is_none() {
let permission_result = if security_prompt.is_none() {
// No security message - show all options including "Always Allow"
cliclack::select(prompt)
.item(Permission::AllowOnce, "Allow", "Allow the tool call once")
Expand All @@ -883,13 +895,12 @@ impl CliSession {
};

let permission = match permission_result {
Ok(p) => p, // If Ok, use the selected permission
Ok(p) => p,
Err(e) => {
// Check if the error is an interruption (Ctrl+C/Cmd+C, Escape)
if e.kind() == std::io::ErrorKind::Interrupted {
Permission::Cancel // If interrupted, set permission to Cancel
Permission::Cancel
} else {
return Err(e.into()); // Otherwise, convert and propagate the original error
return Err(e.into());
}
}
};
Expand All @@ -899,18 +910,18 @@ impl CliSession {

let mut response_message = Message::user();
response_message.content.push(MessageContent::tool_response(
confirmation.id.clone(),
id.clone(),
Err(ErrorData { code: ErrorCode::INVALID_REQUEST, message: std::borrow::Cow::from("Tool call cancelled by user".to_string()), data: None })
));
self.messages.push(response_message);
cancel_token_clone.cancel();
drop(stream);
break;
} else {
self.agent.handle_confirmation(confirmation.id.clone(), PermissionConfirmation {
self.agent.handle_confirmation(id.clone(), PermissionConfirmation {
principal_type: PrincipalType::Tool,
permission,
},).await;
}).await;
}
}
else {
Expand Down
9 changes: 8 additions & 1 deletion crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use anstream::println;
use bat::WrappingMode;
use console::{measure_text_width, style, Color, Term};
use goose::config::Config;
use goose::conversation::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::conversation::message::{
ActionRequiredData, Message, MessageContent, ToolRequest, ToolResponse,
};
use goose::providers::pricing::get_model_pricing;
use goose::providers::pricing::parse_model_id;
use goose::utils::safe_truncate;
Expand Down Expand Up @@ -166,6 +168,11 @@ pub fn render_message(message: &Message, debug: bool) {

for content in &message.content {
match content {
MessageContent::ActionRequired(action) => match &action.data {
ActionRequiredData::ToolConfirmation { tool_name, .. } => {
println!("action_required(tool_confirmation): {}", tool_name)
}
},
MessageContent::Text(text) => print_markdown(&text.text, theme),
MessageContent::ToolRequest(req) => render_tool_request(req, theme, debug),
MessageContent::ToolResponse(resp) => render_tool_response(resp, theme, debug),
Expand Down
12 changes: 7 additions & 5 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use goose::config::declarative_providers::{
DeclarativeProviderConfig, LoadedProvider, ProviderEngine,
};
use goose::conversation::message::{
FrontendToolRequest, Message, MessageContent, MessageMetadata, RedactedThinkingContent,
SystemNotificationContent, SystemNotificationType, ThinkingContent, TokenState,
ToolConfirmationRequest, ToolRequest, ToolResponse,
ActionRequired, ActionRequiredData, FrontendToolRequest, Message, MessageContent,
MessageMetadata, RedactedThinkingContent, SystemNotificationContent, SystemNotificationType,
ThinkingContent, TokenState, ToolConfirmationRequest, ToolRequest, ToolResponse,
};

use crate::routes::recipe_utils::RecipeManifest;
Expand Down Expand Up @@ -358,7 +358,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::agent::agent_remove_extension,
super::routes::agent::update_agent_provider,
super::routes::agent::update_router_tool_selector,
super::routes::reply::confirm_permission,
super::routes::action_required::confirm_tool_action,
super::routes::reply::reply,
super::routes::session::list_sessions,
super::routes::session::get_session,
Expand Down Expand Up @@ -411,7 +411,7 @@ derive_utoipa!(Icon as IconSchema);
super::routes::config_management::UpdateCustomProviderRequest,
super::routes::config_management::CheckProviderRequest,
super::routes::config_management::SetProviderRequest,
super::routes::reply::PermissionConfirmationRequest,
super::routes::action_required::ConfirmToolActionRequest,
super::routes::reply::ChatRequest,
super::routes::session::ImportSessionRequest,
super::routes::session::SessionListResponse,
Expand All @@ -438,6 +438,8 @@ derive_utoipa!(Icon as IconSchema);
ToolResponse,
ToolRequest,
ToolConfirmationRequest,
ActionRequired,
ActionRequiredData,
ThinkingContent,
RedactedThinkingContent,
FrontendToolRequest,
Expand Down
104 changes: 104 additions & 0 deletions crates/goose-server/src/routes/action_required.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use crate::state::AppState;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::{Permission, PermissionConfirmation};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use utoipa::ToSchema;

#[derive(Debug, Deserialize, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ConfirmToolActionRequest {
id: String,
#[serde(default = "default_principal_type")]
principal_type: PrincipalType,
action: String,
session_id: String,
}

fn default_principal_type() -> PrincipalType {
PrincipalType::Tool
}

#[utoipa::path(
post,
path = "/action-required/tool-confirmation",
request_body = ConfirmToolActionRequest,
responses(
(status = 200, description = "Tool confirmation action is confirmed", body = Value),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
)
)]
pub async fn confirm_tool_action(
State(state): State<Arc<AppState>>,
Json(request): Json<ConfirmToolActionRequest>,
) -> Result<Json<Value>, StatusCode> {
let agent = state.get_agent_for_route(request.session_id).await?;
let permission = match request.action.as_str() {
"always_allow" => Permission::AlwaysAllow,
"allow_once" => Permission::AllowOnce,
"deny" => Permission::DenyOnce,
_ => Permission::DenyOnce,
};

agent
.handle_confirmation(
request.id.clone(),
PermissionConfirmation {
principal_type: request.principal_type,
permission,
},
)
.await;

Ok(Json(Value::Object(serde_json::Map::new())))
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route(
"/action-required/tool-confirmation",
post(confirm_tool_action),
)
.with_state(state)
}

#[cfg(test)]
mod tests {
use super::*;

mod integration_tests {
use super::*;
use axum::{body::Body, http::Request};
use tower::ServiceExt;

#[tokio::test(flavor = "multi_thread")]
async fn test_tool_confirmation_endpoint() {
let state = AppState::new().await.unwrap();

let app = routes(state);

let request = Request::builder()
.uri("/action-required/tool-confirmation")
.method("POST")
.header("content-type", "application/json")
.header("x-secret-key", "test-secret")
.body(Body::from(
serde_json::to_string(&ConfirmToolActionRequest {
id: "test-id".to_string(),
principal_type: PrincipalType::Tool,
action: "allow_once".to_string(),
session_id: "test-session".to_string(),
})
.unwrap(),
))
.unwrap();

let response = app.oneshot(request).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}
}
}
2 changes: 2 additions & 0 deletions crates/goose-server/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod action_required;
pub mod agent;
pub mod audio;
pub mod config_management;
Expand All @@ -22,6 +23,7 @@ pub fn configure(state: Arc<crate::state::AppState>, secret_key: String) -> Rout
Router::new()
.merge(status::routes())
.merge(reply::routes(state.clone()))
.merge(action_required::routes(state.clone()))
.merge(agent::routes(state.clone()))
.merge(audio::routes(state.clone()))
.merge(config_management::routes(state.clone()))
Expand Down
56 changes: 1 addition & 55 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@ use axum::{
};
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::agents::{AgentEvent, SessionConfig};
use goose::conversation::message::{Message, MessageContent, TokenState};
use goose::conversation::Conversation;
use goose::permission::{Permission, PermissionConfirmation};
use goose::session::SessionManager;
use goose::{
agents::{AgentEvent, SessionConfig},
permission::permission_confirmation::PrincipalType,
};
use rmcp::model::ServerNotification;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
convert::Infallible,
pin::Pin,
Expand All @@ -30,7 +25,6 @@ use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use utoipa::ToSchema;

fn track_tool_telemetry(content: &MessageContent, all_messages: &[Message]) {
match content {
Expand Down Expand Up @@ -452,60 +446,12 @@ pub async fn reply(
Ok(SseResponse::new(stream))
}

#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct PermissionConfirmationRequest {
id: String,
#[serde(default = "default_principal_type")]
principal_type: PrincipalType,
action: String,
session_id: String,
}

fn default_principal_type() -> PrincipalType {
PrincipalType::Tool
}

#[utoipa::path(
post,
path = "/confirm",
request_body = PermissionConfirmationRequest,
responses(
(status = 200, description = "Permission action is confirmed", body = Value),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
)
)]
pub async fn confirm_permission(
State(state): State<Arc<AppState>>,
Json(request): Json<PermissionConfirmationRequest>,
) -> Result<Json<Value>, StatusCode> {
let agent = state.get_agent_for_route(request.session_id).await?;
let permission = match request.action.as_str() {
"always_allow" => Permission::AlwaysAllow,
"allow_once" => Permission::AllowOnce,
"deny" => Permission::DenyOnce,
_ => Permission::DenyOnce,
};

agent
.handle_confirmation(
request.id.clone(),
PermissionConfirmation {
principal_type: request.principal_type,
permission,
},
)
.await;
Ok(Json(Value::Object(serde_json::Map::new())))
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route(
"/reply",
post(reply).layer(DefaultBodyLimit::max(50 * 1024 * 1024)),
)
.route("/confirm", post(confirm_permission))
.with_state(state)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/tool_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl Agent {
});

let confirmation = Message::assistant()
.with_tool_confirmation_request(
.with_action_required(
request.id.clone(),
tool_call.name.to_string().clone(),
tool_call.arguments.clone().unwrap_or_default(),
Expand Down
Loading