diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 66b89ea39db9..a283ab2bf7a0 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -73,10 +73,15 @@ fn create_tasks_from_params( } fn create_task_execution_payload(tasks: &[Task], sub_recipe: &SubRecipe) -> Value { + let execution_mode = if tasks.len() == 1 || sub_recipe.sequential_when_repeated { + ExecutionMode::Sequential + } else { + ExecutionMode::Parallel + }; let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); json!({ "task_ids": task_ids, - "execution_mode": if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential } else { ExecutionMode::Parallel }, + "execution_mode": execution_mode, }) } diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 81d728886eab..faa2bcd5a578 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -25,18 +25,7 @@ pub async fn execute_tasks( ) .map_err(|e| format!("Failed to parse task_ids: {}", e))?; - let mut tasks = Vec::new(); - for task_id in &task_ids { - match tasks_manager.get_task(task_id).await { - Some(task) => tasks.push(task), - None => { - return Err(format!( - "Task with ID '{}' not found in TasksManager", - task_id - )) - } - } - } + let tasks = tasks_manager.get_tasks(&task_ids).await?; let task_count = tasks.len(); match execution_mode { diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs index 334379fa4ef5..4864994b7a0d 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; @@ -33,6 +34,22 @@ impl TasksManager { let tasks = self.tasks.read().await; tasks.get(task_id).cloned() } + + pub async fn get_tasks(&self, task_ids: &[String]) -> Result, String> { + let mut tasks = Vec::new(); + for task_id in task_ids { + match self.get_task(task_id).await { + Some(task) => tasks.push(task), + None => { + return Err(format!( + "Task with ID '{}' not found in TasksManager", + task_id + )) + } + } + } + Ok(tasks) + } } #[cfg(test)]