Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 92 additions & 9 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::agents::recipe_tools::dynamic_task_tools::{
use crate::agents::retry::{RetryManager, RetryResult};
use crate::agents::router_tools::ROUTER_LLM_SEARCH_TOOL_NAME;
use crate::agents::sub_recipe_manager::SubRecipeManager;
use crate::agents::subagent_execution_tool::lib::ExecutionMode;
use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{
self, SUBAGENT_EXECUTE_TASK_TOOL_NAME,
};
Expand Down Expand Up @@ -297,14 +298,20 @@ impl Agent {
permission_check_result: &PermissionCheckResult,
message_tool_response: Arc<Mutex<Message>>,
cancel_token: Option<tokio_util::sync::CancellationToken>,
session: Option<SessionConfig>,
) -> Result<Vec<(String, ToolStream)>> {
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();

// Handle pre-approved and read-only tools
for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() {
let (req_id, tool_result) = self
.dispatch_tool_call(tool_call, request.id.clone(), cancel_token.clone())
.dispatch_tool_call(
tool_call,
request.id.clone(),
cancel_token.clone(),
session.clone(),
)
.await;

tool_futures.push((
Expand Down Expand Up @@ -384,6 +391,7 @@ impl Agent {
tool_call: CallToolRequestParam,
request_id: String,
cancellation_token: Option<CancellationToken>,
session: Option<SessionConfig>,
) -> (String, Result<ToolCallResult, ErrorData>) {
if tool_call.name == PLATFORM_MANAGE_SCHEDULE_TOOL_NAME {
let arguments = tool_call
Expand Down Expand Up @@ -451,16 +459,89 @@ impl Agent {
.dispatch_sub_recipe_tool_call(&tool_call.name, arguments, &self.tasks_manager)
.await
} else if tool_call.name == SUBAGENT_EXECUTE_TASK_TOOL_NAME {
let provider = self.provider().await.ok();
let arguments = tool_call
.arguments
.clone()
.map(Value::Object)
.unwrap_or(Value::Object(serde_json::Map::new()));
let provider = match self.provider().await {
Ok(p) => p,
Err(_) => {
return (
request_id,
Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"Provider is required".to_string(),
None,
)),
);
}
};
let session = match session.as_ref() {
Some(s) => s,
None => {
return (
request_id,
Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"Session is required".to_string(),
None,
)),
);
}
};
let parent_session_id = session.id.to_string();
let parent_working_dir = session.working_dir.clone();

let task_config = TaskConfig::new(
provider,
parent_session_id,
parent_working_dir,
get_enabled_extensions(),
);

let arguments = match tool_call.arguments.clone() {
Some(args) => Value::Object(args),
None => {
return (
request_id,
Err(ErrorData::new(
ErrorCode::INVALID_PARAMS,
"Tool call arguments are required".to_string(),
None,
)),
);
}
};
let task_ids: Vec<String> = match arguments.get("task_ids") {
Some(v) => match serde_json::from_value(v.clone()) {
Ok(ids) => ids,
Err(_) => {
return (
request_id,
Err(ErrorData::new(
ErrorCode::INVALID_PARAMS,
"Invalid task_ids format".to_string(),
None,
)),
);
}
},
None => {
return (
request_id,
Err(ErrorData::new(
ErrorCode::INVALID_PARAMS,
"task_ids parameter is required".to_string(),
None,
)),
);
}
};

let execution_mode = arguments
.get("execution_mode")
.and_then(|v| serde_json::from_value::<ExecutionMode>(v.clone()).ok())
.unwrap_or(ExecutionMode::Sequential);

let task_config = TaskConfig::new(provider);
subagent_execute_task_tool::run_tasks(
arguments,
task_ids,
execution_mode,
task_config,
&self.tasks_manager,
cancellation_token,
Expand Down Expand Up @@ -1162,6 +1243,7 @@ impl Agent {
&permission_check_result,
message_tool_response.clone(),
cancel_token.clone(),
session.clone(),
).await?;

let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
Expand All @@ -1172,6 +1254,7 @@ impl Agent {
tool_futures_arc.clone(),
message_tool_response.clone(),
cancel_token.clone(),
session.clone(),
&inspection_results,
);

Expand Down
2 changes: 0 additions & 2 deletions crates/goose/src/agents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ mod router_tool_selector;
mod router_tools;
mod schedule_tool;
pub mod sub_recipe_manager;
pub mod subagent;
pub mod subagent_execution_tool;
pub mod subagent_handler;
mod subagent_task_config;
Expand All @@ -30,6 +29,5 @@ pub use agent::{Agent, AgentEvent};
pub use extension::ExtensionConfig;
pub use extension_manager::ExtensionManager;
pub use prompt_manager::PromptManager;
pub use subagent::{SubAgent, SubAgentProgress, SubAgentStatus};
pub use subagent_task_config::TaskConfig;
pub use types::{FrontendTool, RetryConfig, SessionConfig, SuccessCheck};
42 changes: 0 additions & 42 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,48 +84,6 @@ impl Agent {
Ok((tools, toolshim_tools, system_prompt))
}

/// Generate a response from the LLM provider
/// Handles toolshim transformations if needed
pub(crate) async fn generate_response_from_provider(
provider: Arc<dyn Provider>,
system_prompt: &str,
messages: &[Message],
tools: &[Tool],
toolshim_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let config = provider.get_model_config();

// Convert tool messages to text if toolshim is enabled
let messages_for_provider = if config.toolshim {
convert_tool_messages_to_text(messages)
} else {
Conversation::new_unvalidated(messages.to_vec())
};

// Call the provider to get a response
let (mut response, mut usage) = provider
.complete(system_prompt, messages_for_provider.messages(), tools)
.await?;

// Ensure we have token counts, estimating if necessary
usage
.ensure_tokens(
system_prompt,
messages_for_provider.messages(),
&response,
tools,
)
.await?;

crate::providers::base::set_current_model(&usage.model);

if config.toolshim {
response = toolshim_postprocess(response, toolshim_tools).await?;
}

Ok((response, usage))
}

/// Stream a response from the LLM provider.
/// Handles toolshim transformations if needed
pub(crate) async fn stream_response_from_provider(
Expand Down
Loading
Loading