From 75d3cd1a8aaced66754156ca58cb1b451f7a0d48 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 7 Jul 2025 13:10:24 +1000 Subject: [PATCH 01/43] initial version of run sub recipe multiple times --- .../goose-cli/src/recipes/extract_from_cli.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 187 ++++++++++++++++++ .../recipe_tools/sub_recipe_tools/tests.rs | 74 +++++++ .../sub_recipe_execute_task_tool.rs | 22 ++- crates/goose/src/agents/sub_recipe_manager.rs | 14 +- crates/goose/src/recipe/mod.rs | 15 ++ 6 files changed, 298 insertions(+), 15 deletions(-) diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 0550ee26fef4..84d578c0cae4 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -32,6 +32,7 @@ pub fn extract_recipe_info_from_cli( path: recipe_file_path.to_string_lossy().to_string(), name, values: None, + executions: None, }; all_sub_recipes.push(additional_sub_recipe); } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 928cf8bd0845..7ef2e4737bc6 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::{collections::HashMap, fs}; use anyhow::Result; @@ -9,6 +10,7 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; +#[allow(dead_code)] pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); Tool::new( @@ -25,6 +27,32 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } +#[allow(dead_code)] +pub fn create_multiple_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { + let input_schema = get_input_schema_for_multiple_sub_recipe_task(sub_recipe).unwrap(); + Tool::new( + format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), + format!( + "Create one or more tasks to run the '{}' sub recipe. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same sub recipe but with different parameter values. \ + This is useful when you need to execute the same sub recipe multiple times with varying inputs. \ + After creating the task list, pass it to the task executor to run all tasks.", + sub_recipe.name + ), + input_schema, + Some(ToolAnnotations { + title: Some(format!("create multiple sub recipe tasks for {}", sub_recipe.name)), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + fn get_sub_recipe_parameter_definition( sub_recipe: &SubRecipe, ) -> Result>> { @@ -34,6 +62,7 @@ fn get_sub_recipe_parameter_definition( Ok(recipe.parameters) } +#[allow(dead_code)] fn get_input_schema(sub_recipe: &SubRecipe) -> Result { let mut sub_recipe_params_map = HashMap::::new(); if let Some(params_with_value) = &sub_recipe.values { @@ -73,6 +102,74 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result { } } +#[allow(dead_code)] +fn get_input_schema_for_multiple_sub_recipe_task(sub_recipe: &SubRecipe) -> Result { + let mut sub_recipe_values = HashSet::::new(); + if let Some(params_with_value) = &sub_recipe.values { + for param_name in params_with_value.keys() { + sub_recipe_values.insert(param_name.clone()); + } + } + + if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { + for run in runs { + if let Some(params_with_value) = &run.values { + for param_name in params_with_value.keys() { + sub_recipe_values.insert(param_name.clone()); + } + } + } + } + + let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; + + let mut param_properties = Map::new(); + let mut param_required = Vec::new(); + + if let Some(parameters) = parameter_definition { + for param in parameters { + if sub_recipe_values.contains(¶m.key.clone()) { + continue; + } + param_properties.insert( + param.key.clone(), + json!({ + "type": param.input_type.to_string(), + "description": param.description.clone(), + }), + ); + if !matches!(param.requirement, RecipeParameterRequirement::Optional) { + param_required.push(param.key); + } + } + } + + // Create the schema using only task_parameters for both single and multiple tasks + let mut properties = Map::new(); + if !param_properties.is_empty() { + properties.insert( + "task_parameters".to_string(), + json!({ + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": param_properties, + "required": param_required + }, + }) + ); + } + Ok(json!({ + "type": "object", + "properties": properties, + })) +} + +#[allow(dead_code)] fn prepare_command_params( sub_recipe: &SubRecipe, params_from_tool_call: Value, @@ -94,6 +191,51 @@ fn prepare_command_params( Ok(sub_recipe_params) } +fn prepare_command_params_for_multiple_sub_recipe_task( + sub_recipe: &SubRecipe, + params_from_tool_call: Vec, +) -> Result>> { + let mut sub_recipe_params = HashMap::::new(); + if let Some(params_with_value) = &sub_recipe.values { + for (param_name, param_value) in params_with_value { + sub_recipe_params.insert(param_name.clone(), param_value.clone()); + } + } + let mut sub_recipe_run_params = Vec::>::new(); + if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { + for run in runs { + let mut sub_recipe_run_param = sub_recipe_params.clone(); + if let Some(params_with_value) = &run.values { + sub_recipe_run_param.extend(params_with_value.clone()); + } + sub_recipe_run_params.push(sub_recipe_run_param); + } + } + println!("===== sub_recipe_run_params: {:?}", sub_recipe_run_params); + println!("===== params_from_tool_call: {:?}", params_from_tool_call); + if params_from_tool_call.is_empty() { + return Ok(sub_recipe_run_params); + } + if sub_recipe_run_params.len() > 0 && sub_recipe_run_params.len() != params_from_tool_call.len() + { + return Err(anyhow::anyhow!( + "The number of runs in the sub recipe does not match the number of task parameters" + )); + } + let mut sub_recipe_run_params = vec![sub_recipe_params.clone(); params_from_tool_call.len()]; + for (index, sub_recipe_task_param) in params_from_tool_call.iter().enumerate() { + if let Some(params_with_value) = sub_recipe_task_param.as_object() { + for (key, value) in params_with_value { + sub_recipe_run_params[index] + .entry(key.to_string()) + .or_insert_with(|| value.as_str().unwrap_or(&value.to_string()).to_string()); + } + } + } + Ok(sub_recipe_run_params) +} + +#[allow(dead_code)] pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { let command_params = prepare_command_params(sub_recipe, params)?; let payload = json!({ @@ -113,5 +255,50 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re Ok(task_json) } +pub async fn create_multiple_sub_recipe_tasks( + sub_recipe: &SubRecipe, + params: Value, +) -> Result { + // Get the task_parameters array + let empty_vec = vec![]; + let task_params_array = params + .get("task_parameters") + .and_then(|v| v.as_array()) + .unwrap_or(&empty_vec); + let command_params = + prepare_command_params_for_multiple_sub_recipe_task(sub_recipe, task_params_array.clone())?; + let tasks = command_params + .iter() + .map(|task_command_param| { + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": task_command_param, + "recipe_path": sub_recipe.path.clone(), + } + }); + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + } + }) + .collect::>(); + let is_parallel = sub_recipe + .executions + .as_ref() + .map(|e| e.parallel) + .unwrap_or(false); + let task_execution_payload = json!({ + "tasks": tasks, + "execution_mode": if is_parallel { "parallel" } else { "sequential" } + }); + println!("===== task_execution_payload: {:?}", task_execution_payload); + + let tasks_json = serde_json::to_string(&task_execution_payload) + .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + Ok(tasks_json) +} + #[cfg(test)] mod tests; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 11ce390a6b3b..4fa8574d5b13 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -9,6 +9,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + executions: None, }; sub_recipe } @@ -42,6 +43,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: None, + executions: None, }; let params: HashMap = HashMap::new(); let params_value = serde_json::to_value(params).unwrap(); @@ -113,6 +115,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: None, + executions: None, }; let sub_recipe_file_content = r#"{ @@ -152,4 +155,75 @@ mod tests { assert_eq!(result["required"][0], "key1"); } } + + mod get_input_schema_for_multiple_sub_recipe_task_tests { + use crate::{ + agents::recipe_tools::sub_recipe_tools::get_input_schema_for_multiple_sub_recipe_task, + recipe::SubRecipe, + }; + + #[test] + fn test_get_input_schema_for_multiple_tasks() { + let sub_recipe_file_content = r#"{ + "version": "1.0.0", + "title": "Test Recipe", + "description": "A test recipe", + "prompt": "Test prompt", + "parameters": [ + { + "key": "param1", + "input_type": "string", + "requirement": "required", + "description": "A required parameter" + }, + { + "key": "param2", + "input_type": "number", + "requirement": "optional", + "description": "An optional parameter" + } + ] + }"#; + + let temp_dir = tempfile::tempdir().unwrap(); + let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); + std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: temp_file.to_string_lossy().to_string(), + values: None, + executions: None, + }; + + let result = get_input_schema_for_multiple_sub_recipe_task(&sub_recipe).unwrap(); + + // Verify the schema structure + assert_eq!(result["type"], "object"); + assert!(result["properties"].is_object()); + + let properties = result["properties"].as_object().unwrap(); + assert_eq!(properties.len(), 1); + assert!(properties.contains_key("task_parameters")); + + let task_params = &properties["task_parameters"]; + assert_eq!(task_params["type"], "array"); + assert_eq!(task_params["minItems"], 1); + + let items = &task_params["items"]; + assert_eq!(items["type"], "object"); + + let item_properties = items["properties"].as_object().unwrap(); + assert_eq!(item_properties.len(), 2); + assert!(item_properties.contains_key("param1")); + assert!(item_properties.contains_key("param2")); + + assert_eq!(item_properties["param1"]["type"], "string"); + assert_eq!(item_properties["param2"]["type"], "number"); + + let required = items["required"].as_array().unwrap(); + assert_eq!(required.len(), 1); + assert_eq!(required[0], "param1"); + } + } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 46738b813b13..0e7e061fd4d1 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -10,19 +10,31 @@ pub fn create_sub_recipe_execute_task_tool() -> Tool { Tool::new( SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, "Only use this tool when you execute sub recipe task. -EXECUTION STRATEGY: -- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution -- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' + +EXECUTION STRATEGY DECISION: +1. PRE-CREATED TASKS: If tasks were created by subrecipe__create_task_* tools, check the execution_mode in the response: + - If execution_mode is 'parallel', use parallel execution + - If execution_mode is 'sequential', use sequential execution + - Always respect the execution_mode from task creation to maintain consistency + +2. USER INTENT: If creating tasks inline or user explicitly specifies: + - DEFAULT: Execute tasks sequentially unless user explicitly requests parallel execution + - PARALLEL: When user uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' IMPLEMENTATION: - Sequential execution: Call this tool multiple times, passing exactly ONE task per call - Parallel execution: Call this tool once, passing an ARRAY of all tasks EXAMPLES: +User Intent Based: - User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) - User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) - User: 'run these simultaneously' → Parallel (1 tool call with task array) -- User: 'do task A then task B' → Sequential (2 separate tool calls)", +- User: 'do task A then task B' → Sequential (2 separate tool calls) + +Pre-created Task Based: +- subrecipe__create_task_weather returns execution_mode: 'parallel' → Use parallel execution +- subrecipe__create_task_weather returns execution_mode: 'sequential' → Use sequential execution", serde_json::json!({ "type": "object", "properties": { @@ -30,7 +42,7 @@ EXAMPLES: "type": "string", "enum": ["sequential", "parallel"], "default": "sequential", - "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." + "description": "Execution strategy for multiple tasks. For pre-created tasks, respect the execution_mode from task creation. For user intent, use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." }, "tasks": { "type": "array", diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 2441684b4b0e..76d7effb48ee 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -5,7 +5,8 @@ use std::collections::HashMap; use crate::{ agents::{ recipe_tools::sub_recipe_tools::{ - create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, + create_multiple_sub_recipe_task_tool, create_multiple_sub_recipe_tasks, + SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, tool_execution::ToolCallResult, }, @@ -34,18 +35,12 @@ impl SubRecipeManager { pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec) { for sub_recipe in sub_recipes_to_add { - // let sub_recipe_key = format!( - // "{}_{}", - // SUB_RECIPE_TOOL_NAME_PREFIX, - // sub_recipe.name.clone() - // ); - // let tool = create_sub_recipe_tool(&sub_recipe); let sub_recipe_key = format!( "{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name.clone() ); - let tool = create_sub_recipe_task_tool(&sub_recipe); + let tool = create_multiple_sub_recipe_task_tool(&sub_recipe); self.sub_recipe_tools.insert(sub_recipe_key.clone(), tool); self.sub_recipes.insert(sub_recipe_key.clone(), sub_recipe); } @@ -111,8 +106,7 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - - let output = create_sub_recipe_task(sub_recipe, params) + let output = create_multiple_sub_recipe_tasks(sub_recipe, params) .await .map_err(|e| { ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 55ff144f781e..75c08a42fa16 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -137,6 +137,21 @@ pub struct SubRecipe { pub path: String, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub executions: Option, +} +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Execution { + #[serde(default)] + pub parallel: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub runs: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ExecutionRun { + #[serde(default, deserialize_with = "deserialize_value_map_as_string")] + pub values: Option>, } fn deserialize_value_map_as_string<'de, D>( From cb0e6a5db20a54d858e0ac0571da473ccfff91f6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 7 Jul 2025 17:53:52 +1000 Subject: [PATCH 02/43] added tests, rename functions --- .../agents/recipe_tools/sub_recipe_tools.rs | 321 ++++------- .../recipe_tools/sub_recipe_tools/tests.rs | 520 ++++++++++++------ crates/goose/src/agents/sub_recipe_manager.rs | 7 +- 3 files changed, 484 insertions(+), 364 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 7ef2e4737bc6..e50ee588700f 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -10,26 +10,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -#[allow(dead_code)] pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); - Tool::new( - format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), - "Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(), - input_schema, - Some(ToolAnnotations { - title: Some(format!("create sub recipe task {}", sub_recipe.name)), - read_only_hint: false, - destructive_hint: true, - idempotent_hint: false, - open_world_hint: true, - }), - ) -} - -#[allow(dead_code)] -pub fn create_multiple_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { - let input_schema = get_input_schema_for_multiple_sub_recipe_task(sub_recipe).unwrap(); Tool::new( format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), format!( @@ -53,6 +35,45 @@ pub fn create_multiple_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } +pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { + let empty_vec = vec![]; + let task_params_array = params + .get("task_parameters") + .and_then(|v| v.as_array()) + .unwrap_or(&empty_vec); + let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; + let tasks = command_params + .iter() + .map(|task_command_param| { + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": task_command_param, + "recipe_path": sub_recipe.path.clone(), + } + }); + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + } + }) + .collect::>(); + let is_parallel = sub_recipe + .executions + .as_ref() + .map(|e| e.parallel) + .unwrap_or(false); + let task_execution_payload = json!({ + "tasks": tasks, + "execution_mode": if is_parallel { "parallel" } else { "sequential" } + }); + + let tasks_json = serde_json::to_string(&task_execution_payload) + .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + Ok(tasks_json) +} + fn get_sub_recipe_parameter_definition( sub_recipe: &SubRecipe, ) -> Result>> { @@ -62,89 +83,26 @@ fn get_sub_recipe_parameter_definition( Ok(recipe.parameters) } -#[allow(dead_code)] -fn get_input_schema(sub_recipe: &SubRecipe) -> Result { - let mut sub_recipe_params_map = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params_map.insert(param_name.clone(), param_value.clone()); - } - } - let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; - if let Some(parameters) = parameter_definition { - let mut properties = Map::new(); - let mut required = Vec::new(); - for param in parameters { - if sub_recipe_params_map.contains_key(¶m.key) { - continue; - } - properties.insert( - param.key.clone(), - json!({ - "type": param.input_type.to_string(), - "description": param.description.clone(), - }), - ); - if !matches!(param.requirement, RecipeParameterRequirement::Optional) { - required.push(param.key); - } - } - Ok(json!({ - "type": "object", - "properties": properties, - "required": required - })) - } else { - Ok(json!({ - "type": "object", - "properties": {} - })) - } -} - -#[allow(dead_code)] -fn get_input_schema_for_multiple_sub_recipe_task(sub_recipe: &SubRecipe) -> Result { - let mut sub_recipe_values = HashSet::::new(); +fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet { + let mut sub_recipe_params_with_values = HashSet::::new(); if let Some(params_with_value) = &sub_recipe.values { for param_name in params_with_value.keys() { - sub_recipe_values.insert(param_name.clone()); + sub_recipe_params_with_values.insert(param_name.clone()); } } - if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { for run in runs { if let Some(params_with_value) = &run.values { for param_name in params_with_value.keys() { - sub_recipe_values.insert(param_name.clone()); + sub_recipe_params_with_values.insert(param_name.clone()); } } } } + sub_recipe_params_with_values +} - let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; - - let mut param_properties = Map::new(); - let mut param_required = Vec::new(); - - if let Some(parameters) = parameter_definition { - for param in parameters { - if sub_recipe_values.contains(¶m.key.clone()) { - continue; - } - param_properties.insert( - param.key.clone(), - json!({ - "type": param.input_type.to_string(), - "description": param.description.clone(), - }), - ); - if !matches!(param.requirement, RecipeParameterRequirement::Optional) { - param_required.push(param.key); - } - } - } - - // Create the schema using only task_parameters for both single and multiple tasks +fn create_input_schema(param_properties: Map, param_required: Vec) -> Value { let mut properties = Map::new(); if !param_properties.is_empty() { properties.insert( @@ -163,141 +121,106 @@ fn get_input_schema_for_multiple_sub_recipe_task(sub_recipe: &SubRecipe) -> Resu }) ); } - Ok(json!({ + json!({ "type": "object", "properties": properties, - })) + }) } -#[allow(dead_code)] -fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Value, -) -> Result> { - let mut sub_recipe_params = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params.insert(param_name.clone(), param_value.clone()); - } - } - if let Some(params_map) = params_from_tool_call.as_object() { - for (key, value) in params_map { - sub_recipe_params.insert( - key.to_string(), - value.as_str().unwrap_or(&value.to_string()).to_string(), +fn get_input_schema(sub_recipe: &SubRecipe) -> Result { + let sub_recipe_params_with_values = get_params_with_values(sub_recipe); + + let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; + + let mut param_properties = Map::new(); + let mut param_required = Vec::new(); + + if let Some(parameters) = parameter_definition { + for param in parameters { + if sub_recipe_params_with_values.contains(¶m.key.clone()) { + continue; + } + param_properties.insert( + param.key.clone(), + json!({ + "type": param.input_type.to_string(), + "description": param.description.clone(), + }), ); + if !matches!(param.requirement, RecipeParameterRequirement::Optional) { + param_required.push(param.key); + } } } - Ok(sub_recipe_params) + Ok(create_input_schema(param_properties, param_required)) +} + +fn extract_run_params( + sub_recipe: &SubRecipe, +) -> (HashMap, Vec>) { + let base_params = sub_recipe.values.clone().unwrap_or_default(); + + let run_params = sub_recipe + .executions + .as_ref() + .and_then(|e| e.runs.as_ref()) + .map(|runs| { + runs.iter() + .map(|run| { + let mut params = base_params.clone(); + if let Some(run_values) = &run.values { + params.extend(run_values.clone()); + } + params + }) + .collect::>() + }) + .unwrap_or_default(); + (base_params, run_params) } -fn prepare_command_params_for_multiple_sub_recipe_task( +fn prepare_command_params( sub_recipe: &SubRecipe, params_from_tool_call: Vec, ) -> Result>> { - let mut sub_recipe_params = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params.insert(param_name.clone(), param_value.clone()); - } - } - let mut sub_recipe_run_params = Vec::>::new(); - if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { - for run in runs { - let mut sub_recipe_run_param = sub_recipe_params.clone(); - if let Some(params_with_value) = &run.values { - sub_recipe_run_param.extend(params_with_value.clone()); - } - sub_recipe_run_params.push(sub_recipe_run_param); - } - } - println!("===== sub_recipe_run_params: {:?}", sub_recipe_run_params); - println!("===== params_from_tool_call: {:?}", params_from_tool_call); + let (base_params, run_params) = extract_run_params(sub_recipe); + if params_from_tool_call.is_empty() { - return Ok(sub_recipe_run_params); + return Ok(run_params); } - if sub_recipe_run_params.len() > 0 && sub_recipe_run_params.len() != params_from_tool_call.len() - { + + if !run_params.is_empty() && run_params.len() != params_from_tool_call.len() { return Err(anyhow::anyhow!( - "The number of runs in the sub recipe does not match the number of task parameters" + "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", + run_params.len(), + params_from_tool_call.len() )); } - let mut sub_recipe_run_params = vec![sub_recipe_params.clone(); params_from_tool_call.len()]; - for (index, sub_recipe_task_param) in params_from_tool_call.iter().enumerate() { - if let Some(params_with_value) = sub_recipe_task_param.as_object() { - for (key, value) in params_with_value { - sub_recipe_run_params[index] - .entry(key.to_string()) - .or_insert_with(|| value.as_str().unwrap_or(&value.to_string()).to_string()); - } - } - } - Ok(sub_recipe_run_params) -} -#[allow(dead_code)] -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { - let command_params = prepare_command_params(sub_recipe, params)?; - let payload = json!({ - "sub_recipe": { - "name": sub_recipe.name.clone(), - "command_parameters": command_params, - "recipe_path": sub_recipe.path.clone(), - } - }); - let task = Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "sub_recipe".to_string(), - payload, + let run_params_for_merging = if run_params.is_empty() { + vec![base_params; params_from_tool_call.len()] + } else { + run_params }; - let task_json = serde_json::to_string(&task) - .map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?; - Ok(task_json) -} -pub async fn create_multiple_sub_recipe_tasks( - sub_recipe: &SubRecipe, - params: Value, -) -> Result { - // Get the task_parameters array - let empty_vec = vec![]; - let task_params_array = params - .get("task_parameters") - .and_then(|v| v.as_array()) - .unwrap_or(&empty_vec); - let command_params = - prepare_command_params_for_multiple_sub_recipe_task(sub_recipe, task_params_array.clone())?; - let tasks = command_params - .iter() - .map(|task_command_param| { - let payload = json!({ - "sub_recipe": { - "name": sub_recipe.name.clone(), - "command_parameters": task_command_param, - "recipe_path": sub_recipe.path.clone(), + let merged_params = params_from_tool_call + .into_iter() + .zip(run_params_for_merging) + .map(|(tool_param, mut run_param_map)| { + if let Some(param_obj) = tool_param.as_object() { + for (key, value) in param_obj { + let value_str = value + .as_str() + .map(String::from) + .unwrap_or_else(|| value.to_string()); + run_param_map.entry(key.clone()).or_insert(value_str); } - }); - Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "sub_recipe".to_string(), - payload, } + run_param_map }) - .collect::>(); - let is_parallel = sub_recipe - .executions - .as_ref() - .map(|e| e.parallel) - .unwrap_or(false); - let task_execution_payload = json!({ - "tasks": tasks, - "execution_mode": if is_parallel { "parallel" } else { "sequential" } - }); - println!("===== task_execution_payload: {:?}", task_execution_payload); + .collect(); - let tasks_json = serde_json::to_string(&task_execution_payload) - .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; - Ok(tasks_json) + Ok(merged_params) } #[cfg(test)] diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 4fa8574d5b13..fac6a095bbfe 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -2,9 +2,12 @@ mod tests { use std::collections::HashMap; - use crate::recipe::SubRecipe; + use crate::recipe::{Execution, ExecutionRun, SubRecipe}; + use serde_json::json; + use serde_json::Value; + use tempfile::TempDir; - fn setup_sub_recipe() -> SubRecipe { + fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), @@ -13,58 +16,225 @@ mod tests { }; sub_recipe } - mod prepare_command_params_tests { - use std::collections::HashMap; - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - prepare_command_params, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; + fn create_execution_values(key: &str, values: Vec) -> Execution { + let runs = values + .iter() + .map(|value| ExecutionRun { + values: Some(HashMap::from([(key.to_string(), value.to_string())])), + }) + .collect(); + Execution { + parallel: true, + runs: Some(runs), + } + } - #[test] - fn test_prepare_command_params_basic() { - let mut params = HashMap::new(); - params.insert("key2".to_string(), "value2".to_string()); + mod prepare_command_params_tests { + use super::*; - let sub_recipe = setup_sub_recipe(); + use crate::agents::recipe_tools::sub_recipe_tools::{ + prepare_command_params, tests::tests::setup_default_sub_recipe, + }; - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result.get("key1"), Some(&"value1".to_string())); - assert_eq!(result.get("key2"), Some(&"value2".to_string())); + mod without_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result.len(), 0); + } } - #[test] - fn test_prepare_command_params_empty() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - executions: None, - }; - let params: HashMap = HashMap::new(); - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 0); + mod with_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = + Some(create_execution_values("key2", vec!["value2".to_string()])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ("key3".to_string(), "value3".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + sub_recipe.executions = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_return_command_param_when_all_from_values_in_sub_recipe() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( + ) { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array); + + assert!(result.is_err()); + } } } - mod get_input_schema_tests { - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - get_input_schema, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; + mod get_input_schema { + use super::*; + use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; + + fn prepare_sub_recipe(sub_recipe_file_content: &str) -> (SubRecipe, TempDir) { + let mut sub_recipe = setup_default_sub_recipe(); + let temp_dir = tempfile::tempdir().unwrap(); + let temp_file = temp_dir.path().join(sub_recipe.path.clone()); + std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + sub_recipe.path = temp_file.to_string_lossy().to_string(); + (sub_recipe, temp_dir) + } - #[test] - fn test_get_input_schema_with_parameters() { - let sub_recipe = setup_sub_recipe(); + fn verify_task_parameters(result: Value, expected_task_parameters_items: Value) { + let task_parameters = result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + .unwrap() + .as_object() + .unwrap(); + let task_parameters_items = task_parameters.get("items").unwrap(); + assert_eq!(&expected_task_parameters_items, task_parameters_items); + } + + mod without_execution_runs { + use super::*; - let sub_recipe_file_content = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -85,40 +255,75 @@ mod tests { ] }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); - - let result = get_input_schema(&sub_recipe).unwrap(); - - // Verify the schema structure - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key2_prop = &properties["key2"]; - assert_eq!(key2_prop["type"], "number"); - assert_eq!(key2_prop["description"], "An optional parameter"); - - let required = result["required"].as_array().unwrap(); - assert_eq!(required.len(), 0); + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": [] + }), + ); + } + + #[test] + fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = get_input_schema(&sub_recipe).unwrap(); + + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } + + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = None; + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A test parameter" }, + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": ["key1"] + }), + ); + } } - #[test] - fn test_get_input_schema_no_parameters_values() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - executions: None, - }; + mod execution_runs { + use super::*; - let sub_recipe_file_content = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -128,102 +333,95 @@ mod tests { "key": "key1", "input_type": "string", "requirement": "required", - "description": "A test parameter" - } - ] - }"#; - - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); - - let result = get_input_schema(&sub_recipe).unwrap(); - - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key1_prop = &properties["key1"]; - assert_eq!(key1_prop["type"], "string"); - assert_eq!(key1_prop["description"], "A test parameter"); - assert_eq!(result["required"].as_array().unwrap().len(), 1); - assert_eq!(result["required"][0], "key1"); - } - } - - mod get_input_schema_for_multiple_sub_recipe_task_tests { - use crate::{ - agents::recipe_tools::sub_recipe_tools::get_input_schema_for_multiple_sub_recipe_task, - recipe::SubRecipe, - }; - - #[test] - fn test_get_input_schema_for_multiple_tasks() { - let sub_recipe_file_content = r#"{ - "version": "1.0.0", - "title": "Test Recipe", - "description": "A test recipe", - "prompt": "Test prompt", - "parameters": [ - { - "key": "param1", - "input_type": "string", - "requirement": "required", - "description": "A required parameter" + "description": "A required string parameter" }, { - "key": "param2", + "key": "key2", "input_type": "number", "requirement": "optional", "description": "An optional parameter" + }, + { + "key": "key3", + "input_type": "date", + "requirement": "required", + "description": "A required date parameter" } ] }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: temp_file.to_string_lossy().to_string(), - values: None, - executions: None, - }; - - let result = get_input_schema_for_multiple_sub_recipe_task(&sub_recipe).unwrap(); - - // Verify the schema structure - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - assert!(properties.contains_key("task_parameters")); - - let task_params = &properties["task_parameters"]; - assert_eq!(task_params["type"], "array"); - assert_eq!(task_params["minItems"], 1); - - let items = &task_params["items"]; - assert_eq!(items["type"], "object"); - - let item_properties = items["properties"].as_object().unwrap(); - assert_eq!(item_properties.len(), 2); - assert!(item_properties.contains_key("param1")); - assert!(item_properties.contains_key("param2")); - - assert_eq!(item_properties["param1"]["type"], "string"); - assert_eq!(item_properties["param2"]["type"], "number"); - - let required = items["required"].as_array().unwrap(); - assert_eq!(required.len(), 1); - assert_eq!(required[0], "param1"); + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value_1".to_string(), "key2_value_2".to_string()], + )); + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key3": { "type": "date", "description": "A required date parameter" } + }, + "required": ["key3"] + }), + ); + } + + #[test] + fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value_1".to_string(), "key2_value_2".to_string()], + )); + + let result = get_input_schema(&sub_recipe).unwrap(); + + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } + + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); + sub_recipe.values = None; + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A required string parameter" }, + "key2": { "type": "number", "description": "An optional parameter" }, + "key3": { "type": "date", "description": "A required date parameter" } + }, + "required": ["key1", "key3"] + }), + ); + } } } } diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 76d7effb48ee..d759914a9d9e 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -5,8 +5,7 @@ use std::collections::HashMap; use crate::{ agents::{ recipe_tools::sub_recipe_tools::{ - create_multiple_sub_recipe_task_tool, create_multiple_sub_recipe_tasks, - SUB_RECIPE_TASK_TOOL_NAME_PREFIX, + create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, tool_execution::ToolCallResult, }, @@ -40,7 +39,7 @@ impl SubRecipeManager { SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name.clone() ); - let tool = create_multiple_sub_recipe_task_tool(&sub_recipe); + let tool = create_sub_recipe_task_tool(&sub_recipe); self.sub_recipe_tools.insert(sub_recipe_key.clone(), tool); self.sub_recipes.insert(sub_recipe_key.clone(), sub_recipe); } @@ -106,7 +105,7 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - let output = create_multiple_sub_recipe_tasks(sub_recipe, params) + let output = create_sub_recipe_task(sub_recipe, params) .await .map_err(|e| { ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) From 476651ffcb4418e320386b9bee25b7c460800683 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 14:13:09 +1000 Subject: [PATCH 03/43] get sub recipe json output --- .../agents/sub_recipe_execution_tool/tasks.rs | 89 ++++++++++++------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 4e4584aa0b34..b934a13344fc 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -7,12 +7,10 @@ use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; -// Process a single task based on its type pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - // Execute with timeout match timeout(timeout_duration, execute_task(task_clone)).await { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), @@ -36,6 +34,17 @@ pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { } async fn execute_task(task: Task) -> Result { + let (command, output_identifier) = build_command(&task)?; + let (stdout_output, stderr_output, success) = run_command(command, &output_identifier).await?; + + if success { + process_output(stdout_output) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } +} + +fn build_command(task: &Task) -> Result<(Command, String), String> { let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { let sub_recipe = task.payload.get("sub_recipe").unwrap(); @@ -66,11 +75,12 @@ async fn execute_task(task: Task) -> Result { cmd }; - // Configure to capture stdout command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + Ok((command, output_identifier)) +} - // Spawn the child process +async fn run_command(mut command: Command, output_identifier: &str) -> Result<(String, String, bool), String> { let mut child = command .spawn() .map_err(|e| format!("Failed to spawn goose: {}", e))?; @@ -78,30 +88,8 @@ async fn execute_task(task: Task) -> Result { let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let mut stdout_reader = BufReader::new(stdout).lines(); - let mut stderr_reader = BufReader::new(stderr).lines(); - - // Spawn background tasks to read from stdout and stderr - let output_identifier_clone = output_identifier.clone(); - let stdout_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stdout_reader.next_line().await { - println!("[{}] {}", output_identifier_clone, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); - - let stderr_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stderr_reader.next_line().await { - eprintln!("[stderr for {}] {}", output_identifier, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); + let stdout_task = spawn_output_reader(stdout, output_identifier, false); + let stderr_task = spawn_output_reader(stderr, output_identifier, true); let status = child .wait() @@ -111,9 +99,46 @@ async fn execute_task(task: Task) -> Result { let stdout_output = stdout_task.await.unwrap(); let stderr_output = stderr_task.await.unwrap(); - if status.success() { - Ok(Value::String(stdout_output)) - } else { - Err(format!("Command failed:\n{}", stderr_output)) + Ok((stdout_output, stderr_output, status.success())) +} + +fn spawn_output_reader( + reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, + output_identifier: &str, + is_stderr: bool, +) -> tokio::task::JoinHandle { + let output_identifier = output_identifier.to_string(); + tokio::spawn(async move { + let mut buffer = String::new(); + let mut lines = BufReader::new(reader).lines(); + while let Ok(Some(line)) = lines.next_line().await { + if is_stderr { + eprintln!("[stderr for {}] {}", output_identifier, line); + } else { + println!("[{}] {}", output_identifier, line); + } + buffer.push_str(&line); + buffer.push('\n'); + } + buffer + }) +} + +fn process_output(stdout_output: String) -> Result { + let last_line = stdout_output + .lines() + .filter(|line| !line.trim().is_empty()) + .last() + .unwrap_or(""); + + if let (Some(start), Some(end)) = (last_line.find('{'), last_line.rfind('}')) { + if start < end { + let potential_json = &last_line[start..=end]; + + if serde_json::from_str::(potential_json).is_ok() { + return Ok(Value::String(potential_json.to_string())); + } + } } + Ok(Value::String(stdout_output)) } From b50d245abe98fc40a0efcbe123148fae231c6f91 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 16:53:07 +1000 Subject: [PATCH 04/43] added showing parallel running dashboard --- .../sub_recipe_execution_tool/dashboard.rs | 192 +++++++++++++ .../sub_recipe_execution_tool/executor.rs | 223 +++++++++++---- .../agents/sub_recipe_execution_tool/lib.rs | 2 +- .../agents/sub_recipe_execution_tool/mod.rs | 2 + .../agents/sub_recipe_execution_tool/tasks.rs | 61 ++-- .../agents/sub_recipe_execution_tool/types.rs | 57 +++- .../sub_recipe_execution_tool/utils/mod.rs | 65 +++++ .../sub_recipe_execution_tool/utils/tests.rs | 262 ++++++++++++++++++ .../sub_recipe_execution_tool/workers.rs | 105 ++----- 9 files changed, 808 insertions(+), 161 deletions(-) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs new file mode 100644 index 000000000000..68867bb8e5c8 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; +use std::io::{self, Write}; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio::time::{Duration, Instant}; + +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::utils::{ + count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, +}; + +pub struct TaskDashboard { + tasks: Arc>>, + last_display: Arc>, + last_refresh: Arc>, +} + +impl TaskDashboard { + pub fn new(tasks: Vec) -> Self { + let task_map = tasks + .into_iter() + .map(|task| { + let task_id = task.id.clone(); + ( + task_id, + TaskInfo { + task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }, + ) + }) + .collect(); + + Self { + tasks: Arc::new(RwLock::new(task_map)), + last_display: Arc::new(RwLock::new(String::new())), + last_refresh: Arc::new(RwLock::new(Instant::now())), + } + } + + pub async fn start_task(&self, task_id: &str) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = TaskStatus::Running; + task_info.start_time = Some(Instant::now()); + } + drop(tasks); + self.refresh_display().await; + } + + pub async fn complete_task(&self, task_id: &str, result: TaskResult) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = result.status.clone(); + task_info.end_time = Some(Instant::now()); + task_info.result = Some(result); + } + drop(tasks); + self.refresh_display().await; + } + + pub async fn update_task_output(&self, task_id: &str, output: &str) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + // Keep only the last few lines to avoid overwhelming display + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > 2 { + &lines[lines.len() - 2..] + } else { + &lines + }; + + // Strip ANSI escape sequences to prevent color flashing + let clean_output = recent_lines.join("\n"); + task_info.current_output = strip_ansi_codes(&clean_output); + } + drop(tasks); + + // Throttle refreshes to avoid overwhelming the display (max 1 per second) + let now = Instant::now(); + let mut last_refresh = self.last_refresh.write().await; + if now.duration_since(*last_refresh) > Duration::from_millis(1000) { + *last_refresh = now; + drop(last_refresh); + self.refresh_display().await; + } + } + + pub async fn refresh_display(&self) { + let tasks = self.tasks.read().await; + let mut display = String::new(); + + // Clear screen and move to top + display.push_str("\x1b[2J\x1b[H"); + + // Title + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + + // Summary stats + let (total, pending, running, completed, failed) = count_by_status(&tasks); + + display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed\n\n", + total, pending, running, completed, failed)); + + // Task list + let mut task_list: Vec<_> = tasks.values().collect(); + task_list.sort_by_key(|t| &t.task.id); + + for task_info in task_list { + let status_icon = match task_info.status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + }; + + let task_name = get_task_name(task_info); + + display.push_str(&format!( + "{} {} ({})\n", + status_icon, task_name, task_info.task.task_type + )); + + if let Some(start_time) = task_info.start_time { + let duration = if let Some(end_time) = task_info.end_time { + end_time.duration_since(start_time) + } else { + Instant::now().duration_since(start_time) + }; + display.push_str(&format!(" ⏱️ {:.1}s\n", duration.as_secs_f64())); + } + + if matches!(task_info.status, TaskStatus::Running) + && !task_info.current_output.is_empty() + { + let output_preview = truncate_with_ellipsis(&task_info.current_output, 100); + display.push_str(&format!(" 💬 {}\n", output_preview.replace('\n', " | "))); + } + + if let Some(error) = task_info.error() { + let error_preview = truncate_with_ellipsis(error, 80); + display.push_str(&format!(" ⚠️ {}\n", error_preview.replace('\n', " "))); + } + + display.push('\n'); + } + + // Only update display if it changed + let mut last_display = self.last_display.write().await; + if *last_display != display { + print!("{}", display); + io::stdout().flush().unwrap(); + *last_display = display; + } + } + + pub async fn show_final_summary(&self) { + let tasks = self.tasks.read().await; + + println!("\n🎉 Execution Complete!"); + println!("═══════════════════════"); + + let (total, _, _, completed, failed) = count_by_status(&tasks); + + println!("📊 Final Results:"); + println!(" Total Tasks: {}", total); + println!(" ✅ Completed: {}", completed); + println!(" ❌ Failed: {}", failed); + println!( + " 📈 Success Rate: {:.1}%", + (completed as f64 / total as f64) * 100.0 + ); + + if failed > 0 { + println!("\n❌ Failed Tasks:"); + for task_info in tasks.values() { + if matches!(task_info.status, TaskStatus::Failed) { + let task_name = get_task_name(task_info); + println!(" • {}", task_name); + if let Some(error) = task_info.error() { + println!(" Error: {}", error); + } + } + } + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index b796d412984d..9d869666e4af 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -1,103 +1,208 @@ -use std::sync::atomic::{AtomicBool, AtomicUsize}; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; use crate::agents::sub_recipe_execution_tool::lib::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, + Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState}; +use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; + +const EXECUTION_STATUS_COMPLETED: &str = "completed"; pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { let start_time = Instant::now(); let result = process_task(task, config.timeout_seconds).await; let execution_time = start_time.elapsed().as_millis(); - let completed = if result.status == "success" { 1 } else { 0 }; - let failed = if result.status == "failed" { 1 } else { 0 }; + let stats = calculate_stats(&[result.clone()], execution_time); ExecutionResponse { - status: "completed".to_string(), + status: EXECUTION_STATUS_COMPLETED.to_string(), results: vec![result], - stats: ExecutionStats { - total_tasks: 1, - completed, - failed, - execution_time_ms: execution_time, - }, + stats, } } -// Main parallel execution function -pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { - let start_time = Instant::now(); - let task_count = tasks.len(); +fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats { + let completed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Completed)) + .count(); + let failed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .count(); + + ExecutionStats { + total_tasks: results.len(), + completed, + failed, + execution_time_ms, + } +} + +struct ExecutionContext { + tasks: Vec, + config: Config, + dashboard: Arc, + start_time: Instant, +} + +impl ExecutionContext { + fn new(tasks: Vec, config: Config) -> Self { + let dashboard = Arc::new(TaskDashboard::new(tasks.clone())); + + Self { + tasks, + config, + dashboard, + start_time: Instant::now(), + } + } - // Create channels + fn task_count(&self) -> usize { + self.tasks.len() + } +} + +fn create_channels( + task_count: usize, +) -> ( + mpsc::Sender, + mpsc::Receiver, + mpsc::Sender, + mpsc::Receiver, +) { let (task_tx, task_rx) = mpsc::channel::(task_count); - let (result_tx, mut result_rx) = mpsc::channel::(task_count); + let (result_tx, result_rx) = mpsc::channel::(task_count); + (task_tx, task_rx, result_tx, result_rx) +} - // Initialize shared state - let shared_state = Arc::new(SharedState { +fn create_shared_state( + task_rx: mpsc::Receiver, + result_tx: mpsc::Sender, + dashboard: Arc, +) -> Arc { + Arc::new(SharedState { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), - }); + dashboard: Some(dashboard), + }) +} - // Send all tasks to the queue - for task in tasks.clone() { - let _ = task_tx.send(task).await; +async fn send_tasks_to_channel( + tasks: Vec, + task_tx: mpsc::Sender, +) -> Result<(), String> { + for task in tasks { + task_tx + .send(task) + .await + .map_err(|e| format!("Failed to queue task: {}", e))?; } - // Close sender so workers know when queue is empty - drop(task_tx); + Ok(()) +} - // Start initial workers - let mut worker_handles = Vec::new(); - for i in 0..config.initial_workers { - let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); - worker_handles.push(handle); +fn create_empty_response() -> ExecutionResponse { + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results: vec![], + stats: ExecutionStats { + total_tasks: 0, + completed: 0, + failed: 0, + execution_time_ms: 0, + }, } +} - // Start the scaler - let scaler_state = shared_state.clone(); - let scaler_handle = tokio::spawn(async move { - run_scaler( - scaler_state, - task_count, - config.max_workers, - config.timeout_seconds, - ) - .await; - }); - - // Collect results +async fn collect_results( + result_rx: &mut mpsc::Receiver, + dashboard: Arc, + expected_count: usize, +) -> Vec { let mut results = Vec::new(); while let Some(result) = result_rx.recv().await { + dashboard + .complete_task(&result.task_id, result.clone()) + .await; results.push(result); - if results.len() >= task_count { + if results.len() >= expected_count { break; } } + results +} - // Wait for scaler to finish - let _ = scaler_handle.await; +async fn execute_with_context(ctx: ExecutionContext) -> Result { + let task_count = ctx.task_count(); - // Calculate stats - let execution_time = start_time.elapsed().as_millis(); - let completed = results.iter().filter(|r| r.status == "success").count(); - let failed = results.iter().filter(|r| r.status == "failed").count(); + if task_count == 0 { + return Ok(create_empty_response()); + } - ExecutionResponse { - status: "completed".to_string(), + ctx.dashboard.refresh_display().await; + + let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); + + send_tasks_to_channel(ctx.tasks, task_tx).await?; + + let shared_state = create_shared_state(task_rx, result_tx, ctx.dashboard.clone()); + + // Simple static worker allocation - no dynamic scaling needed + let worker_count = std::cmp::min(task_count, ctx.config.max_workers); + let mut worker_handles = Vec::new(); + for i in 0..worker_count { + let handle = spawn_worker(shared_state.clone(), i, ctx.config.timeout_seconds); + worker_handles.push(handle); + } + + let results = collect_results(&mut result_rx, ctx.dashboard.clone(), task_count).await; + + // Wait for all workers to finish + for handle in worker_handles { + if let Err(e) = handle.await { + eprintln!("Worker error: {}", e); + } + } + + ctx.dashboard.show_final_summary().await; + + let execution_time = ctx.start_time.elapsed().as_millis(); + let stats = calculate_stats(&results, execution_time); + + Ok(ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), results, + stats, + }) +} + +fn create_error_response(_error: String) -> ExecutionResponse { + ExecutionResponse { + status: "failed".to_string(), + results: vec![], stats: ExecutionStats { - total_tasks: task_count, - completed, - failed, - execution_time_ms: execution_time, + total_tasks: 0, + completed: 0, + failed: 1, + execution_time_ms: 0, }, } } + +pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { + let ctx = ExecutionContext::new(tasks, config); + + match execute_with_context(ctx).await { + Ok(response) => response, + Err(e) => { + eprintln!("Execution failed: {}", e); + create_error_response(e) + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 9df784a46be0..6c4718c2b3a0 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -1,7 +1,7 @@ use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, + Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index a49791e2776f..6f131862fefd 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,6 +1,8 @@ +mod dashboard; mod executor; pub mod lib; pub mod sub_recipe_execute_task_tool; mod tasks; mod types; +pub mod utils; mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index b934a13344fc..9543f0cec504 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -1,42 +1,53 @@ use serde_json::Value; use std::process::Stdio; +use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { + process_task_with_dashboard(task, timeout_seconds, None).await +} + +pub async fn process_task_with_dashboard( + task: &Task, + timeout_seconds: u64, + dashboard: Option>, +) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - match timeout(timeout_duration, execute_task(task_clone)).await { + match timeout(timeout_duration, execute_task(task_clone, dashboard)).await { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), - status: "success".to_string(), + status: TaskStatus::Completed, data: Some(data), error: None, }, Ok(Err(error)) => TaskResult { task_id: task.id.clone(), - status: "failed".to_string(), + status: TaskStatus::Failed, data: None, error: Some(error), }, Err(_) => TaskResult { task_id: task.id.clone(), - status: "failed".to_string(), + status: TaskStatus::Failed, data: None, error: Some("Task timeout".to_string()), }, } } -async fn execute_task(task: Task) -> Result { +async fn execute_task(task: Task, dashboard: Option>) -> Result { let (command, output_identifier) = build_command(&task)?; - let (stdout_output, stderr_output, success) = run_command(command, &output_identifier).await?; - + let (stdout_output, stderr_output, success) = + run_command(command, &output_identifier, &task.id, dashboard).await?; + if success { process_output(stdout_output) } else { @@ -80,7 +91,12 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { Ok((command, output_identifier)) } -async fn run_command(mut command: Command, output_identifier: &str) -> Result<(String, String, bool), String> { +async fn run_command( + mut command: Command, + output_identifier: &str, + task_id: &str, + dashboard: Option>, +) -> Result<(String, String, bool), String> { let mut child = command .spawn() .map_err(|e| format!("Failed to spawn goose: {}", e))?; @@ -88,8 +104,9 @@ async fn run_command(mut command: Command, output_identifier: &str) -> Result<(S let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let stdout_task = spawn_output_reader(stdout, output_identifier, false); - let stderr_task = spawn_output_reader(stderr, output_identifier, true); + let stdout_task = + spawn_output_reader(stdout, output_identifier, false, task_id, dashboard.clone()); + let stderr_task = spawn_output_reader(stderr, output_identifier, true, task_id, None); let status = child .wait() @@ -106,19 +123,25 @@ fn spawn_output_reader( reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, output_identifier: &str, is_stderr: bool, + task_id: &str, + dashboard: Option>, ) -> tokio::task::JoinHandle { let output_identifier = output_identifier.to_string(); + let task_id = task_id.to_string(); tokio::spawn(async move { let mut buffer = String::new(); let mut lines = BufReader::new(reader).lines(); while let Ok(Some(line)) = lines.next_line().await { - if is_stderr { - eprintln!("[stderr for {}] {}", output_identifier, line); - } else { - println!("[{}] {}", output_identifier, line); - } buffer.push_str(&line); buffer.push('\n'); + + if !is_stderr { + if let Some(dashboard) = &dashboard { + dashboard.update_task_output(&task_id, &buffer).await; + } + } else { + eprintln!("[stderr for {}] {}", output_identifier, line); + } } buffer }) @@ -128,13 +151,13 @@ fn process_output(stdout_output: String) -> Result { let last_line = stdout_output .lines() .filter(|line| !line.trim().is_empty()) - .last() + .next_back() .unwrap_or(""); - + if let (Some(start), Some(end)) = (last_line.find('{'), last_line.rfind('}')) { if start < end { let potential_json = &last_line[start..=end]; - + if serde_json::from_str::(potential_json).is_ok() { return Ok(Value::String(potential_json.to_string())); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index ede71dbf40b4..2f40490a564c 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -1,7 +1,11 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc; + +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; -// Task definition that LLMs will send #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { pub id: String, @@ -9,18 +13,61 @@ pub struct Task { pub payload: Value, } -// Result for each task #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TaskResult { pub task_id: String, - pub status: String, + pub status: TaskStatus, #[serde(skip_serializing_if = "Option::is_none")] pub data: Option, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } -// Configuration for the parallel executor +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum TaskStatus { + Pending, + Running, + Completed, + Failed, +} + +#[derive(Debug, Clone)] +pub struct TaskInfo { + pub task: Task, + pub status: TaskStatus, + pub start_time: Option, + pub end_time: Option, + pub result: Option, + pub current_output: String, +} + +impl TaskInfo { + pub fn error(&self) -> Option<&String> { + self.result.as_ref().and_then(|r| r.error.as_ref()) + } + + pub fn data(&self) -> Option<&Value> { + self.result.as_ref().and_then(|r| r.data.as_ref()) + } +} + +pub struct SharedState { + pub task_receiver: Arc>>, + pub result_sender: mpsc::Sender, + pub active_workers: Arc, + pub dashboard: Option>, +} + +impl SharedState { + pub fn increment_active_workers(&self) { + self.active_workers.fetch_add(1, Ordering::SeqCst); + } + + pub fn decrement_active_workers(&self) { + self.active_workers.fetch_sub(1, Ordering::SeqCst); + } +} + #[derive(Debug, Clone, Deserialize)] pub struct Config { #[serde(default = "default_max_workers")] @@ -51,7 +98,6 @@ fn default_initial_workers() -> usize { 2 } -// Stats for the execution #[derive(Debug, Serialize)] pub struct ExecutionStats { pub total_tasks: usize, @@ -60,7 +106,6 @@ pub struct ExecutionStats { pub execution_time_ms: u128, } -// Main response structure #[derive(Debug, Serialize)] pub struct ExecutionResponse { pub status: String, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs new file mode 100644 index 000000000000..ebd12e7f11e3 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; + +use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; + +pub fn get_task_name(task_info: &TaskInfo) -> &str { + if task_info.task.task_type == "sub_recipe" { + task_info + .task + .payload + .get("sub_recipe") + .and_then(|sr| sr.get("name")) + .and_then(|n| n.as_str()) + .unwrap_or(&task_info.task.id) + } else { + &task_info.task.id + } +} + +pub fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { + if text.len() > max_len { + format!("{}...", &text[..max_len.saturating_sub(3)]) + } else { + text.to_string() + } +} + +pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { + let total = tasks.len(); + let (pending, running, completed, failed) = tasks.values().fold( + (0, 0, 0, 0), + |(pending, running, completed, failed), task| match task.status { + TaskStatus::Pending => (pending + 1, running, completed, failed), + TaskStatus::Running => (pending, running + 1, completed, failed), + TaskStatus::Completed => (pending, running, completed + 1, failed), + TaskStatus::Failed => (pending, running, completed, failed + 1), + }, + ); + (total, pending, running, completed, failed) +} + +pub fn strip_ansi_codes(text: &str) -> String { + let mut result = String::new(); + let mut chars = text.chars(); + + while let Some(ch) = chars.next() { + if ch == '\x1b' { + if chars.next() == Some('[') { + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } + } + } + } else { + result.push(ch); + } + } + + result +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs new file mode 100644 index 000000000000..a2e1d9b74a8f --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -0,0 +1,262 @@ +#[cfg(test)] +mod tests { + use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; + use crate::agents::sub_recipe_execution_tool::utils::{ + count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, + }; + use serde_json::json; + use std::collections::HashMap; + + mod truncate_with_ellipsis { + use super::*; + + #[test] + fn returns_original_when_under_limit() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); + } + + #[test] + fn truncates_when_over_limit() { + assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); + assert_eq!( + truncate_with_ellipsis("very long text here", 10), + "very lo..." + ); + } + + #[test] + fn handles_empty_string() { + assert_eq!(truncate_with_ellipsis("", 5), ""); + } + + #[test] + fn handles_exact_limit() { + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + } + + #[test] + fn handles_very_short_limit() { + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is + } + } + + mod strip_ansi_codes { + use super::*; + + #[test] + fn preserves_plain_text() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); + } + + #[test] + fn removes_color_codes() { + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); + } + + #[test] + fn removes_complex_formatting() { + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), + "underline red" + ); + } + + #[test] + fn handles_multiple_sequences() { + let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; + assert_eq!(strip_ansi_codes(input), "red normal green"); + } + + #[test] + fn handles_empty_string() { + assert_eq!(strip_ansi_codes(""), ""); + } + + #[test] + fn handles_malformed_sequences() { + // Incomplete escape sequence - our function strips the \x1b char + assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); + } + } + + mod get_task_name { + use super::*; + + #[test] + fn extracts_sub_recipe_name() { + let sub_recipe_task = Task { + id: "task_1".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": "my_recipe", + "recipe_path": "/path/to/recipe" + } + }), + }; + + let task_info = TaskInfo { + task: sub_recipe_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "my_recipe"); + } + + #[test] + fn falls_back_to_task_id_for_text_instruction() { + let text_task = Task { + id: "task_2".to_string(), + task_type: "text_instruction".to_string(), + payload: json!({"text_instruction": "do something"}), + }; + + let task_info = TaskInfo { + task: text_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_2"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_name_missing() { + let malformed_task = Task { + id: "task_3".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "recipe_path": "/path/to/recipe" + // missing "name" field + } + }), + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_3"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_missing() { + let malformed_task = Task { + id: "task_4".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({}), // missing "sub_recipe" field + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_4"); + } + } + + mod count_by_status { + use super::*; + + fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { + TaskInfo { + task: Task { + id: id.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + } + } + + #[test] + fn counts_empty_map() { + let tasks = HashMap::new(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (0, 0, 0, 0, 0) + ); + } + + #[test] + fn counts_single_status() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Pending), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (2, 2, 0, 0, 0) + ); + } + + #[test] + fn counts_mixed_statuses() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Running), + ); + tasks.insert( + "task3".to_string(), + create_test_task("task3", TaskStatus::Completed), + ); + tasks.insert( + "task4".to_string(), + create_test_task("task4", TaskStatus::Failed), + ); + tasks.insert( + "task5".to_string(), + create_test_task("task5", TaskStatus::Completed), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (5, 1, 1, 2, 1) + ); + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index e48f19c4d360..d9f750480432 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,9 +1,7 @@ -use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::tasks::{process_task, process_task_with_dashboard}; +use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task, TaskResult}; use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::{sleep, Duration}; #[cfg(test)] mod tests { @@ -22,6 +20,7 @@ mod tests { active_workers: Arc::new(AtomicUsize::new(0)), should_stop: Arc::new(AtomicBool::new(false)), completed_tasks: Arc::new(AtomicUsize::new(0)), + dashboard: None, }); // Test that spawn_worker returns a JoinHandle @@ -40,21 +39,33 @@ mod tests { } } -pub struct SharedState { - pub task_receiver: Arc>>, - pub result_sender: mpsc::Sender, - pub active_workers: Arc, - pub should_stop: Arc, - pub completed_tasks: Arc, +async fn receive_task(state: &SharedState) -> Option { + let mut receiver = state.task_receiver.lock().await; + receiver.recv().await +} + +async fn execute_task( + task: Task, + timeout: u64, + dashboard: Option>, +) -> TaskResult { + if let Some(dashboard) = &dashboard { + dashboard.start_task(&task.id).await; + } + + if let Some(dashboard) = dashboard { + process_task_with_dashboard(&task, timeout, Some(dashboard)).await + } else { + process_task(&task, timeout).await + } } -// Spawn a worker task pub fn spawn_worker( state: Arc, worker_id: usize, timeout_seconds: u64, ) -> tokio::task::JoinHandle<()> { - state.active_workers.fetch_add(1, Ordering::SeqCst); + state.increment_active_workers(); tokio::spawn(async move { worker_loop(state, worker_id, timeout_seconds).await; @@ -62,72 +73,14 @@ pub fn spawn_worker( } async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { - loop { - // Try to receive a task - let task = { - let mut receiver = state.task_receiver.lock().await; - receiver.recv().await - }; - - match task { - Some(task) => { - // Process the task - let result = process_task(&task, timeout_seconds).await; + while let Some(task) = receive_task(&state).await { + let result = execute_task(task, timeout_seconds, state.dashboard.clone()).await; - // Send result - let _ = state.result_sender.send(result).await; - - // Update completed count - state.completed_tasks.fetch_add(1, Ordering::SeqCst); - } - None => { - // Channel closed, exit worker - break; - } - } - - // Check if we should stop - if state.should_stop.load(Ordering::SeqCst) { + if let Err(e) = state.result_sender.send(result).await { + eprintln!("Worker failed to send result: {}", e); break; } } - // Worker is exiting - state.active_workers.fetch_sub(1, Ordering::SeqCst); -} - -// Scaling controller that monitors queue and spawns workers -pub async fn run_scaler( - state: Arc, - task_count: usize, - max_workers: usize, - timeout_seconds: u64, -) { - let mut worker_count = 0; - - loop { - sleep(Duration::from_millis(100)).await; - - let active = state.active_workers.load(Ordering::SeqCst); - let completed = state.completed_tasks.load(Ordering::SeqCst); - let pending = task_count.saturating_sub(completed); - - // Simple scaling logic: spawn worker if many pending tasks and under limit - if pending > active * 2 && active < max_workers && worker_count < max_workers { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - - // If all tasks completed, signal stop - if completed >= task_count { - state.should_stop.store(true, Ordering::SeqCst); - break; - } - - // If no active workers and tasks remaining, spawn one - if active == 0 && pending > 0 { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - } + state.decrement_active_workers(); } From 9ca6b25e6734d343b87c26b3827ab2c7ed7821bc Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 17:21:20 +1000 Subject: [PATCH 05/43] better view --- .../sub_recipe_execution_tool/dashboard.rs | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs index 68867bb8e5c8..c74cb7024187 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -13,6 +13,7 @@ pub struct TaskDashboard { tasks: Arc>>, last_display: Arc>, last_refresh: Arc>, + initial_display_shown: Arc>, } impl TaskDashboard { @@ -39,6 +40,7 @@ impl TaskDashboard { tasks: Arc::new(RwLock::new(task_map)), last_display: Arc::new(RwLock::new(String::new())), last_refresh: Arc::new(RwLock::new(Instant::now())), + initial_display_shown: Arc::new(RwLock::new(false)), } } @@ -94,20 +96,28 @@ impl TaskDashboard { let tasks = self.tasks.read().await; let mut display = String::new(); - // Clear screen and move to top - display.push_str("\x1b[2J\x1b[H"); - - // Title - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); + let mut initial_shown = self.initial_display_shown.write().await; + if !*initial_shown { + // Clear screen and show header only on first display + display.push_str("\x1b[2J\x1b[H"); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + *initial_shown = true; + } else { + // Move cursor to beginning of progress line (line 4) + display.push_str("\x1b[4;1H"); + } + drop(initial_shown); - // Summary stats + // Summary stats (this line gets updated in-place) let (total, pending, running, completed, failed) = count_by_status(&tasks); - - display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed\n\n", + display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", total, pending, running, completed, failed)); - // Task list + // Clear to end of line and add newlines + display.push_str("\x1b[K\n\n"); + + // Task list (update in-place) let mut task_list: Vec<_> = tasks.values().collect(); task_list.sort_by_key(|t| &t.task.id); @@ -122,9 +132,10 @@ impl TaskDashboard { let task_name = get_task_name(task_info); display.push_str(&format!( - "{} {} ({})\n", + "{} {} ({})", status_icon, task_name, task_info.task.task_type )); + display.push_str("\x1b[K\n"); // Clear to end of line if let Some(start_time) = task_info.start_time { let duration = if let Some(end_time) = task_info.end_time { @@ -132,24 +143,30 @@ impl TaskDashboard { } else { Instant::now().duration_since(start_time) }; - display.push_str(&format!(" ⏱️ {:.1}s\n", duration.as_secs_f64())); + display.push_str(&format!(" ⏱️ {:.1}s", duration.as_secs_f64())); + display.push_str("\x1b[K\n"); // Clear to end of line } if matches!(task_info.status, TaskStatus::Running) && !task_info.current_output.is_empty() { let output_preview = truncate_with_ellipsis(&task_info.current_output, 100); - display.push_str(&format!(" 💬 {}\n", output_preview.replace('\n', " | "))); + display.push_str(&format!(" 💬 {}", output_preview.replace('\n', " | "))); + display.push_str("\x1b[K\n"); // Clear to end of line } if let Some(error) = task_info.error() { let error_preview = truncate_with_ellipsis(error, 80); - display.push_str(&format!(" ⚠️ {}\n", error_preview.replace('\n', " "))); + display.push_str(&format!(" ⚠️ {}", error_preview.replace('\n', " "))); + display.push_str("\x1b[K\n"); // Clear to end of line } - display.push('\n'); + display.push_str("\x1b[K\n"); // Clear to end of line and add blank line } + // Clear any remaining lines below + display.push_str("\x1b[J"); + // Only update display if it changed let mut last_display = self.last_display.write().await; if *last_display != display { From 744669280d5485060fd8f889b0413a8b326de8b1 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 18:04:45 +1000 Subject: [PATCH 06/43] clean up dashboard.rs --- .../sub_recipe_execution_tool/dashboard.rs | 155 ++++++++---------- .../sub_recipe_execution_tool/utils/mod.rs | 117 +++++++++++++ 2 files changed, 183 insertions(+), 89 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs index c74cb7024187..ee8ae16bd4b5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -2,13 +2,19 @@ use std::collections::HashMap; use std::io::{self, Write}; use std::sync::Arc; use tokio::sync::RwLock; -use tokio::time::{Duration, Instant}; +use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, + count_by_status, format_task_display, get_task_name, process_output_lines, }; +const THROTTLE_INTERVAL_MS: u64 = 1000; +const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; +const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; +const CLEAR_TO_EOL: &str = "\x1b[K"; +const CLEAR_BELOW: &str = "\x1b[J"; + pub struct TaskDashboard { tasks: Arc>>, last_display: Arc>, @@ -68,128 +74,96 @@ impl TaskDashboard { pub async fn update_task_output(&self, task_id: &str, output: &str) { let mut tasks = self.tasks.write().await; if let Some(task_info) = tasks.get_mut(task_id) { - // Keep only the last few lines to avoid overwhelming display - let lines: Vec<&str> = output.lines().collect(); - let recent_lines = if lines.len() > 2 { - &lines[lines.len() - 2..] - } else { - &lines - }; - - // Strip ANSI escape sequences to prevent color flashing - let clean_output = recent_lines.join("\n"); - task_info.current_output = strip_ansi_codes(&clean_output); + task_info.current_output = process_output_lines(output); } drop(tasks); - // Throttle refreshes to avoid overwhelming the display (max 1 per second) + if !self.should_throttle_refresh().await { + self.refresh_display().await; + } + } + + async fn should_throttle_refresh(&self) -> bool { let now = Instant::now(); let mut last_refresh = self.last_refresh.write().await; - if now.duration_since(*last_refresh) > Duration::from_millis(1000) { + + if now.duration_since(*last_refresh) > Duration::from_millis(THROTTLE_INTERVAL_MS) { *last_refresh = now; - drop(last_refresh); - self.refresh_display().await; + false + } else { + true } } - pub async fn refresh_display(&self) { - let tasks = self.tasks.read().await; - let mut display = String::new(); - - let mut initial_shown = self.initial_display_shown.write().await; + fn render_header(&self, display: &mut String, initial_shown: &mut bool) { if !*initial_shown { - // Clear screen and show header only on first display - display.push_str("\x1b[2J\x1b[H"); + display.push_str(CLEAR_SCREEN); display.push_str("🎯 Task Execution Dashboard\n"); display.push_str("═══════════════════════════\n\n"); *initial_shown = true; } else { - // Move cursor to beginning of progress line (line 4) - display.push_str("\x1b[4;1H"); + display.push_str(MOVE_TO_PROGRESS_LINE); } - drop(initial_shown); + } - // Summary stats (this line gets updated in-place) - let (total, pending, running, completed, failed) = count_by_status(&tasks); - display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", - total, pending, running, completed, failed)); + fn render_progress_line(&self, display: &mut String, tasks: &HashMap) { + let (total, pending, running, completed, failed) = count_by_status(tasks); + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + total, pending, running, completed, failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); + } - // Clear to end of line and add newlines - display.push_str("\x1b[K\n\n"); + fn render_task(&self, display: &mut String, task_info: &TaskInfo) { + let task_display = format_task_display(task_info, Instant::now()); + display.push_str(&task_display); + } - // Task list (update in-place) - let mut task_list: Vec<_> = tasks.values().collect(); - task_list.sort_by_key(|t| &t.task.id); + async fn update_display_if_changed(&self, display: String) { + let mut last_display = self.last_display.write().await; + if *last_display != display { + print!("{}", display); + io::stdout().flush().unwrap(); + *last_display = display; + } + } - for task_info in task_list { - let status_icon = match task_info.status { - TaskStatus::Pending => "⏳", - TaskStatus::Running => "🏃", - TaskStatus::Completed => "✅", - TaskStatus::Failed => "❌", - }; - - let task_name = get_task_name(task_info); - - display.push_str(&format!( - "{} {} ({})", - status_icon, task_name, task_info.task.task_type - )); - display.push_str("\x1b[K\n"); // Clear to end of line - - if let Some(start_time) = task_info.start_time { - let duration = if let Some(end_time) = task_info.end_time { - end_time.duration_since(start_time) - } else { - Instant::now().duration_since(start_time) - }; - display.push_str(&format!(" ⏱️ {:.1}s", duration.as_secs_f64())); - display.push_str("\x1b[K\n"); // Clear to end of line - } + pub async fn refresh_display(&self) { + let tasks = self.tasks.read().await; + let mut display = String::new(); - if matches!(task_info.status, TaskStatus::Running) - && !task_info.current_output.is_empty() - { - let output_preview = truncate_with_ellipsis(&task_info.current_output, 100); - display.push_str(&format!(" 💬 {}", output_preview.replace('\n', " | "))); - display.push_str("\x1b[K\n"); // Clear to end of line - } + let mut initial_shown = self.initial_display_shown.write().await; + self.render_header(&mut display, &mut initial_shown); + drop(initial_shown); - if let Some(error) = task_info.error() { - let error_preview = truncate_with_ellipsis(error, 80); - display.push_str(&format!(" ⚠️ {}", error_preview.replace('\n', " "))); - display.push_str("\x1b[K\n"); // Clear to end of line - } + self.render_progress_line(&mut display, &tasks); - display.push_str("\x1b[K\n"); // Clear to end of line and add blank line + let mut task_list: Vec<_> = tasks.values().collect(); + task_list.sort_by_key(|t| &t.task.id); + + for task_info in task_list { + self.render_task(&mut display, task_info); } - // Clear any remaining lines below - display.push_str("\x1b[J"); + display.push_str(CLEAR_BELOW); - // Only update display if it changed - let mut last_display = self.last_display.write().await; - if *last_display != display { - print!("{}", display); - io::stdout().flush().unwrap(); - *last_display = display; - } + self.update_display_if_changed(display).await; } pub async fn show_final_summary(&self) { let tasks = self.tasks.read().await; - println!("\n🎉 Execution Complete!"); + println!("Execution Complete!"); println!("═══════════════════════"); let (total, _, _, completed, failed) = count_by_status(&tasks); - println!("📊 Final Results:"); - println!(" Total Tasks: {}", total); - println!(" ✅ Completed: {}", completed); - println!(" ❌ Failed: {}", failed); + println!("Total Tasks: {}", total); + println!("✅ Completed: {}", completed); + println!("❌ Failed: {}", failed); println!( - " 📈 Success Rate: {:.1}%", + "📈 Success Rate: {:.1}%", (completed as f64 / total as f64) * 100.0 ); @@ -205,5 +179,8 @@ impl TaskDashboard { } } } + + println!("\n📝 Generating summary..."); + sleep(Duration::from_millis(500)).await; } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index ebd12e7f11e3..31f20273ec9e 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,7 +1,14 @@ use std::collections::HashMap; +use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; +// Constants for display formatting +const MAX_OUTPUT_LINES: usize = 2; +const OUTPUT_PREVIEW_LENGTH: usize = 100; +const ERROR_PREVIEW_LENGTH: usize = 80; +const CLEAR_TO_EOL: &str = "\x1b[K"; + pub fn get_task_name(task_info: &TaskInfo) -> &str { if task_info.task.task_type == "sub_recipe" { task_info @@ -61,5 +68,115 @@ pub fn strip_ansi_codes(text: &str) -> String { result } +// Pure utility functions for dashboard rendering + +/// Get status icon for a given task status +pub fn get_status_icon(status: &TaskStatus) -> &'static str { + match status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + } +} + +/// Process output lines, keeping only recent lines and stripping ANSI codes +pub fn process_output_lines(output: &str) -> String { + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > MAX_OUTPUT_LINES { + &lines[lines.len() - MAX_OUTPUT_LINES..] + } else { + &lines + }; + + let clean_output = recent_lines.join("\n"); + strip_ansi_codes(&clean_output) +} + +/// Format task timing information +pub fn format_task_timing(task_info: &TaskInfo, current_time: Instant) -> Option { + task_info.start_time.map(|start_time| { + let duration = if let Some(end_time) = task_info.end_time { + end_time.duration_since(start_time) + } else { + current_time.duration_since(start_time) + }; + format!( + " ⏱️ {:.1}s{} +", + duration.as_secs_f64(), + CLEAR_TO_EOL + ) + }) +} + +/// Format task output preview +pub fn format_task_output(task_info: &TaskInfo) -> Option { + if matches!(task_info.status, TaskStatus::Running) && !task_info.current_output.is_empty() { + let output_preview = + truncate_with_ellipsis(&task_info.current_output, OUTPUT_PREVIEW_LENGTH); + Some(format!( + " 💬 {}{} +", + output_preview.replace('\n', " | "), + CLEAR_TO_EOL + )) + } else { + None + } +} + +/// Format task error information +pub fn format_task_error(task_info: &TaskInfo) -> Option { + task_info.error().map(|error| { + let error_preview = truncate_with_ellipsis(error, ERROR_PREVIEW_LENGTH); + format!( + " ⚠️ {}{} +", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + ) + }) +} + +/// Format complete task display +pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> String { + let mut display = String::new(); + + let status_icon = get_status_icon(&task_info.status); + let task_name = get_task_name(task_info); + + // Task status line + display.push_str(&format!( + "{} {} ({}){} +", + status_icon, task_name, task_info.task.task_type, CLEAR_TO_EOL + )); + + // Task timing + if let Some(timing) = format_task_timing(task_info, current_time) { + display.push_str(&timing); + } + + // Task output (if running) + if let Some(output) = format_task_output(task_info) { + display.push_str(&output); + } + + // Task error (if failed) + if let Some(error) = format_task_error(task_info) { + display.push_str(&error); + } + + // Empty line + display.push_str(&format!( + "{} +", + CLEAR_TO_EOL + )); + + display +} + #[cfg(test)] mod tests; From b84840913bbc92eb29c729fe3723b2f86450ba6b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 18:13:29 +1000 Subject: [PATCH 07/43] fixed tests --- .../sub_recipe_execution_tool/utils/tests.rs | 381 +++++++++++++++++- 1 file changed, 369 insertions(+), 12 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index a2e1d9b74a8f..b0d3e5e0dac6 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,14 +1,15 @@ -#[cfg(test)] -mod tests { - use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; - use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, - }; - use serde_json::json; - use std::collections::HashMap; - - mod truncate_with_ellipsis { - use super::*; +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::utils::{ + count_by_status, format_task_display, format_task_error, format_task_output, + format_task_timing, get_status_icon, get_task_name, process_output_lines, strip_ansi_codes, + truncate_with_ellipsis, +}; +use serde_json::json; +use std::collections::HashMap; +use tokio::time::Instant; + +mod truncate_with_ellipsis { + use super::*; #[test] fn returns_original_when_under_limit() { @@ -259,4 +260,360 @@ mod tests { ); } } -} + + mod get_status_icon { + use super::*; + + #[test] + fn returns_correct_icon_for_pending() { + assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); + } + + #[test] + fn returns_correct_icon_for_running() { + assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); + } + + #[test] + fn returns_correct_icon_for_completed() { + assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); + } + + #[test] + fn returns_correct_icon_for_failed() { + assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); + } + } + + mod process_output_lines { + use super::*; + + #[test] + fn preserves_short_output() { + let output = "line 1\nline 2"; + assert_eq!(process_output_lines(output), "line 1\nline 2"); + } + + #[test] + fn keeps_only_recent_lines_when_too_many() { + let output = "line 1\nline 2\nline 3\nline 4\nline 5"; + let result = process_output_lines(output); + assert_eq!(result, "line 4\nline 5"); + } + + #[test] + fn strips_ansi_codes_from_output() { + let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "red line 1\ngreen line 2"); + } + + #[test] + fn handles_empty_output() { + assert_eq!(process_output_lines(""), ""); + } + + #[test] + fn handles_single_line() { + assert_eq!(process_output_lines("single line"), "single line"); + } + + #[test] + fn combines_ansi_stripping_and_line_limiting() { + let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "line 3\nline 4"); + } + } + + mod format_task_timing { + use super::*; + use std::time::Duration; + + fn create_test_task_info_with_timing( + start: Option, + end: Option, + ) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Running, + start_time: start, + end_time: end, + result: None, + current_output: String::new(), + } + } + + #[test] + fn returns_none_when_no_start_time() { + let task_info = create_test_task_info_with_timing(None, None); + let current_time = Instant::now(); + assert!(format_task_timing(&task_info, current_time).is_none()); + } + + #[test] + fn formats_running_task_duration() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_millis(1500); + let task_info = create_test_task_info_with_timing(Some(start_time), None); + + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("1.5s")); + assert!(result.contains("⏱️")); + } + + #[test] + fn formats_completed_task_duration() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(2500); + let current_time = Instant::now(); // This shouldn't matter for completed tasks + let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); + + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("2.5s")); + assert!(result.contains("⏱️")); + } + } + + mod format_task_output { + use super::*; + + fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: output.to_string(), + } + } + + #[test] + fn returns_none_for_non_running_tasks() { + let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); + assert!(format_task_output(&task_info).is_none()); + + let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); + assert!(format_task_output(&task_info).is_none()); + + let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); + assert!(format_task_output(&task_info).is_none()); + } + + #[test] + fn returns_none_for_running_task_with_empty_output() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); + assert!(format_task_output(&task_info).is_none()); + } + + #[test] + fn formats_running_task_with_output() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, "Building project..."); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("💬")); + assert!(result.contains("Building project...")); + } + + #[test] + fn replaces_newlines_with_pipes() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("line 1 | line 2 | line 3")); + } + + #[test] + fn truncates_long_output() { + let long_output = "a".repeat(150); + let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_output.len() + 20); // Account for formatting + } + } + + mod format_task_error { + use super::*; + use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; + + fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { + let result = error_msg.map(|msg| TaskResult { + task_id: "test_task".to_string(), + status: TaskStatus::Failed, + data: None, + error: Some(msg.to_string()), + }); + + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Failed, + start_time: None, + end_time: None, + result, + current_output: String::new(), + } + } + + #[test] + fn returns_none_when_no_error() { + let task_info = create_test_task_info_with_error(None); + assert!(format_task_error(&task_info).is_none()); + } + + #[test] + fn formats_error_message() { + let task_info = create_test_task_info_with_error(Some("File not found")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("⚠️")); + assert!(result.contains("File not found")); + } + + #[test] + fn replaces_newlines_in_error() { + let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("Error on line 1 Error on line 2")); + } + + #[test] + fn truncates_long_error() { + let long_error = "error ".repeat(30); + let task_info = create_test_task_info_with_error(Some(&long_error)); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_error.len() + 20); // Account for formatting + } + } + + mod format_task_display { + use super::*; + use std::time::Duration; + + fn create_comprehensive_task_info( + task_name: &str, + status: TaskStatus, + start_time: Option, + end_time: Option, + current_output: &str, + error: Option<&str>, + ) -> TaskInfo { + let result = error.map(|msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { + task_id: task_name.to_string(), + status: status.clone(), + data: None, + error: Some(msg.to_string()), + }); + + TaskInfo { + task: Task { + id: task_name.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time, + end_time, + result, + current_output: current_output.to_string(), + } + } + + #[test] + fn formats_pending_task() { + let task_info = create_comprehensive_task_info( + "pending_task", + TaskStatus::Pending, + None, + None, + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("⏳")); + assert!(result.contains("pending_task")); + assert!(result.contains("(test)")); + } + + #[test] + fn formats_running_task_with_output() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_secs(2); + let task_info = create_comprehensive_task_info( + "running_task", + TaskStatus::Running, + Some(start_time), + None, + "Compiling...", + None, + ); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("🏃")); + assert!(result.contains("running_task")); + assert!(result.contains("2.0s")); + assert!(result.contains("💬")); + assert!(result.contains("Compiling...")); + } + + #[test] + fn formats_failed_task_with_error() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(1500); + let task_info = create_comprehensive_task_info( + "failed_task", + TaskStatus::Failed, + Some(start_time), + Some(end_time), + "", + Some("Compilation failed"), + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("❌")); + assert!(result.contains("failed_task")); + assert!(result.contains("1.5s")); + assert!(result.contains("⚠️")); + assert!(result.contains("Compilation failed")); + } + + #[test] + fn formats_completed_task() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_secs(3); + let task_info = create_comprehensive_task_info( + "completed_task", + TaskStatus::Completed, + Some(start_time), + Some(end_time), + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("✅")); + assert!(result.contains("completed_task")); + assert!(result.contains("3.0s")); + } + } From a02c0560dbc5f02dfd9bd0bb0e53d92d77f466ff Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 19:39:58 +1000 Subject: [PATCH 08/43] fixed scenario that only run params are specified --- crates/goose/src/agents/recipe_tools/mod.rs | 1 + .../agents/recipe_tools/param_utils/mod.rs | 116 ++ .../agents/recipe_tools/param_utils/tests.rs | 236 ++++ .../agents/recipe_tools/sub_recipe_tools.rs | 72 +- .../recipe_tools/sub_recipe_tools/tests.rs | 174 --- .../sub_recipe_execution_tool/utils/tests.rs | 1065 +++++++++-------- .../sub_recipe_execution_tool/workers.rs | 7 +- 7 files changed, 893 insertions(+), 778 deletions(-) create mode 100644 crates/goose/src/agents/recipe_tools/param_utils/mod.rs create mode 100644 crates/goose/src/agents/recipe_tools/param_utils/tests.rs diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 5f2f95fc8485..90603c88488e 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1 +1,2 @@ +pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs new file mode 100644 index 000000000000..08db06c5e25e --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -0,0 +1,116 @@ +use anyhow::Result; +use serde_json::Value; +use std::collections::HashMap; + +use crate::recipe::SubRecipe; + +pub fn extract_run_params( + sub_recipe: &SubRecipe, +) -> (HashMap, Vec>) { + let base_params = sub_recipe.values.clone().unwrap_or_default(); + + let run_params = sub_recipe + .executions + .as_ref() + .and_then(|e| e.runs.as_ref()) + .map(|runs| { + runs.iter() + .map(|run| { + let mut params = base_params.clone(); + if let Some(run_values) = &run.values { + params.extend(run_values.clone()); + } + params + }) + .collect::>() + }) + .unwrap_or_default(); + (base_params, run_params) +} + +pub fn validate_param_counts( + run_params: &[HashMap], + params_from_tool_call: &[Value], +) -> Result<()> { + if !run_params.is_empty() + && run_params.len() != params_from_tool_call.len() + && params_from_tool_call.len() > 1 + { + return Err(anyhow::anyhow!( + "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", + run_params.len(), + params_from_tool_call.len() + )); + } + Ok(()) +} + +pub fn prepare_base_params( + base_params: &HashMap, + run_params: &[HashMap], + params_from_tool_call: &[Value], +) -> Vec> { + if run_params.is_empty() { + vec![base_params.clone(); params_from_tool_call.len()] + } else { + run_params.to_vec() + } +} + +pub fn prepare_tool_params( + params_from_tool_call: &[Value], + run_params: &[HashMap], +) -> Vec { + if params_from_tool_call.len() == 1 && run_params.len() > 1 { + vec![params_from_tool_call[0].clone(); run_params.len()] + } else { + params_from_tool_call.to_vec() + } +} + +pub fn merge_parameters( + tool_params: Vec, + base_params: Vec>, +) -> Vec> { + tool_params + .into_iter() + .zip(base_params) + .map(|(tool_param, mut base_param_map)| { + if let Some(param_obj) = tool_param.as_object() { + for (key, value) in param_obj { + let value_str = value + .as_str() + .map(String::from) + .unwrap_or_else(|| value.to_string()); + base_param_map.entry(key.clone()).or_insert(value_str); + } + } + base_param_map + }) + .collect() +} + +pub fn prepare_command_params( + sub_recipe: &SubRecipe, + params_from_tool_call: Vec, +) -> Result>> { + let (base_params, run_params) = extract_run_params(sub_recipe); + + if params_from_tool_call.is_empty() { + return Ok(run_params); + } + + validate_param_counts(&run_params, ¶ms_from_tool_call)?; + + let base_params_for_merging = + prepare_base_params(&base_params, &run_params, ¶ms_from_tool_call); + let tool_params_for_merging = prepare_tool_params(¶ms_from_tool_call, &run_params); + + Ok(merge_parameters( + tool_params_for_merging, + base_params_for_merging, + )) +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs new file mode 100644 index 000000000000..81e5c412a8c7 --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -0,0 +1,236 @@ +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::recipe::{Execution, ExecutionRun, SubRecipe}; + use serde_json::json; + use serde_json::Value; + + use crate::agents::recipe_tools::param_utils::prepare_command_params; + + fn setup_default_sub_recipe() -> SubRecipe { + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: "test_sub_recipe.yaml".to_string(), + values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + executions: None, + }; + sub_recipe + } + + fn create_execution_values(key: &str, values: Vec) -> Execution { + let runs = values + .iter() + .map(|value| ExecutionRun { + values: Some(HashMap::from([(key.to_string(), value.to_string())])), + }) + .collect(); + Execution { + parallel: true, + runs: Some(runs), + } + } + + mod prepare_command_params_tests { + use super::*; + + mod without_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result.len(), 0); + } + } + + mod with_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = + Some(create_execution_values("key2", vec!["value2".to_string()])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ("key3".to_string(), "value3".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + sub_recipe.executions = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_return_command_param_when_all_from_values_in_sub_recipe() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( + ) { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + ),]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( + ) { + let parameter_array = vec![ + json!(HashMap::from([("key3".to_string(), "value3".to_string())])), + json!(HashMap::from([("key4".to_string(), "value4".to_string())])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array); + + assert!(result.is_err()); + } + } + } +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index e50ee588700f..8cd1a8599037 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::{collections::HashMap, fs}; +use std::fs; use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; @@ -8,6 +8,8 @@ use serde_json::{json, Map, Value}; use crate::agents::sub_recipe_execution_tool::lib::Task; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; +use super::param_utils::prepare_command_params; + pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { @@ -155,73 +157,5 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result { Ok(create_input_schema(param_properties, param_required)) } -fn extract_run_params( - sub_recipe: &SubRecipe, -) -> (HashMap, Vec>) { - let base_params = sub_recipe.values.clone().unwrap_or_default(); - - let run_params = sub_recipe - .executions - .as_ref() - .and_then(|e| e.runs.as_ref()) - .map(|runs| { - runs.iter() - .map(|run| { - let mut params = base_params.clone(); - if let Some(run_values) = &run.values { - params.extend(run_values.clone()); - } - params - }) - .collect::>() - }) - .unwrap_or_default(); - (base_params, run_params) -} - -fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Vec, -) -> Result>> { - let (base_params, run_params) = extract_run_params(sub_recipe); - - if params_from_tool_call.is_empty() { - return Ok(run_params); - } - - if !run_params.is_empty() && run_params.len() != params_from_tool_call.len() { - return Err(anyhow::anyhow!( - "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", - run_params.len(), - params_from_tool_call.len() - )); - } - - let run_params_for_merging = if run_params.is_empty() { - vec![base_params; params_from_tool_call.len()] - } else { - run_params - }; - - let merged_params = params_from_tool_call - .into_iter() - .zip(run_params_for_merging) - .map(|(tool_param, mut run_param_map)| { - if let Some(param_obj) = tool_param.as_object() { - for (key, value) in param_obj { - let value_str = value - .as_str() - .map(String::from) - .unwrap_or_else(|| value.to_string()); - run_param_map.entry(key.clone()).or_insert(value_str); - } - } - run_param_map - }) - .collect(); - - Ok(merged_params) -} - #[cfg(test)] mod tests; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index fac6a095bbfe..a4956f65edda 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -30,180 +30,6 @@ mod tests { } } - mod prepare_command_params_tests { - use super::*; - - use crate::agents::recipe_tools::sub_recipe_tools::{ - prepare_command_params, tests::tests::setup_default_sub_recipe, - }; - - mod without_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "value2".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_value_override_passed_param_value() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "different_value".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_empty_command_param() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!(result.len(), 0); - } - } - - mod with_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = - Some(create_execution_values("key2", vec!["value2".to_string()])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ("key3".to_string(), "value3".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_all_values_from_tool_call_parameters() { - let parameter_array = vec![ - json!(HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()) - ])), - json!(HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()) - ])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - sub_recipe.executions = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ]), - ], - result - ); - } - - #[test] - fn test_return_command_param_when_all_from_values_in_sub_recipe() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( - ) { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array); - - assert!(result.is_err()); - } - } - } - mod get_input_schema { use super::*; use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index b0d3e5e0dac6..88d823092c34 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -11,609 +11,610 @@ use tokio::time::Instant; mod truncate_with_ellipsis { use super::*; - #[test] - fn returns_original_when_under_limit() { - assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); - assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); - } + #[test] + fn returns_original_when_under_limit() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); + } - #[test] - fn truncates_when_over_limit() { - assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); - assert_eq!( - truncate_with_ellipsis("very long text here", 10), - "very lo..." - ); - } + #[test] + fn truncates_when_over_limit() { + assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); + assert_eq!( + truncate_with_ellipsis("very long text here", 10), + "very lo..." + ); + } - #[test] - fn handles_empty_string() { - assert_eq!(truncate_with_ellipsis("", 5), ""); - } + #[test] + fn handles_empty_string() { + assert_eq!(truncate_with_ellipsis("", 5), ""); + } - #[test] - fn handles_exact_limit() { - assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); - } + #[test] + fn handles_exact_limit() { + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + } - #[test] - fn handles_very_short_limit() { - assert_eq!(truncate_with_ellipsis("hello", 3), "..."); - assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is - } + #[test] + fn handles_very_short_limit() { + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is } +} - mod strip_ansi_codes { - use super::*; +mod strip_ansi_codes { + use super::*; - #[test] - fn preserves_plain_text() { - assert_eq!(strip_ansi_codes("hello world"), "hello world"); - assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); - } + #[test] + fn preserves_plain_text() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); + } - #[test] - fn removes_color_codes() { - assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); - assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); - } + #[test] + fn removes_color_codes() { + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); + } - #[test] - fn removes_complex_formatting() { - assert_eq!( - strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), - "bold green" - ); - assert_eq!( - strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), - "underline red" - ); - } + #[test] + fn removes_complex_formatting() { + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), + "underline red" + ); + } - #[test] - fn handles_multiple_sequences() { - let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; - assert_eq!(strip_ansi_codes(input), "red normal green"); - } + #[test] + fn handles_multiple_sequences() { + let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; + assert_eq!(strip_ansi_codes(input), "red normal green"); + } - #[test] - fn handles_empty_string() { - assert_eq!(strip_ansi_codes(""), ""); - } + #[test] + fn handles_empty_string() { + assert_eq!(strip_ansi_codes(""), ""); + } - #[test] - fn handles_malformed_sequences() { - // Incomplete escape sequence - our function strips the \x1b char - assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); - } + #[test] + fn handles_malformed_sequences() { + // Incomplete escape sequence - our function strips the \x1b char + assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); } +} - mod get_task_name { - use super::*; - - #[test] - fn extracts_sub_recipe_name() { - let sub_recipe_task = Task { - id: "task_1".to_string(), - task_type: "sub_recipe".to_string(), - payload: json!({ - "sub_recipe": { - "name": "my_recipe", - "recipe_path": "/path/to/recipe" - } - }), - }; - - let task_info = TaskInfo { - task: sub_recipe_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "my_recipe"); - } +mod get_task_name { + use super::*; - #[test] - fn falls_back_to_task_id_for_text_instruction() { - let text_task = Task { - id: "task_2".to_string(), - task_type: "text_instruction".to_string(), - payload: json!({"text_instruction": "do something"}), - }; - - let task_info = TaskInfo { - task: text_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "task_2"); - } + #[test] + fn extracts_sub_recipe_name() { + let sub_recipe_task = Task { + id: "task_1".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": "my_recipe", + "recipe_path": "/path/to/recipe" + } + }), + }; + + let task_info = TaskInfo { + task: sub_recipe_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "my_recipe"); + } - #[test] - fn falls_back_to_task_id_when_sub_recipe_name_missing() { - let malformed_task = Task { - id: "task_3".to_string(), - task_type: "sub_recipe".to_string(), - payload: json!({ - "sub_recipe": { - "recipe_path": "/path/to/recipe" - // missing "name" field - } - }), - }; - - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "task_3"); - } + #[test] + fn falls_back_to_task_id_for_text_instruction() { + let text_task = Task { + id: "task_2".to_string(), + task_type: "text_instruction".to_string(), + payload: json!({"text_instruction": "do something"}), + }; + + let task_info = TaskInfo { + task: text_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_2"); + } - #[test] - fn falls_back_to_task_id_when_sub_recipe_missing() { - let malformed_task = Task { - id: "task_4".to_string(), - task_type: "sub_recipe".to_string(), - payload: json!({}), // missing "sub_recipe" field - }; - - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "task_4"); - } + #[test] + fn falls_back_to_task_id_when_sub_recipe_name_missing() { + let malformed_task = Task { + id: "task_3".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "recipe_path": "/path/to/recipe" + // missing "name" field + } + }), + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_3"); } - mod count_by_status { - use super::*; - - fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { - TaskInfo { - task: Task { - id: id.to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - } - } + #[test] + fn falls_back_to_task_id_when_sub_recipe_missing() { + let malformed_task = Task { + id: "task_4".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({}), // missing "sub_recipe" field + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_4"); + } +} - #[test] - fn counts_empty_map() { - let tasks = HashMap::new(); - let (total, pending, running, completed, failed) = count_by_status(&tasks); - assert_eq!( - (total, pending, running, completed, failed), - (0, 0, 0, 0, 0) - ); - } +mod count_by_status { + use super::*; - #[test] - fn counts_single_status() { - let mut tasks = HashMap::new(); - tasks.insert( - "task1".to_string(), - create_test_task("task1", TaskStatus::Pending), - ); - tasks.insert( - "task2".to_string(), - create_test_task("task2", TaskStatus::Pending), - ); - - let (total, pending, running, completed, failed) = count_by_status(&tasks); - assert_eq!( - (total, pending, running, completed, failed), - (2, 2, 0, 0, 0) - ); + fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { + TaskInfo { + task: Task { + id: id.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), } + } - #[test] - fn counts_mixed_statuses() { - let mut tasks = HashMap::new(); - tasks.insert( - "task1".to_string(), - create_test_task("task1", TaskStatus::Pending), - ); - tasks.insert( - "task2".to_string(), - create_test_task("task2", TaskStatus::Running), - ); - tasks.insert( - "task3".to_string(), - create_test_task("task3", TaskStatus::Completed), - ); - tasks.insert( - "task4".to_string(), - create_test_task("task4", TaskStatus::Failed), - ); - tasks.insert( - "task5".to_string(), - create_test_task("task5", TaskStatus::Completed), - ); - - let (total, pending, running, completed, failed) = count_by_status(&tasks); - assert_eq!( - (total, pending, running, completed, failed), - (5, 1, 1, 2, 1) - ); - } + #[test] + fn counts_empty_map() { + let tasks = HashMap::new(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (0, 0, 0, 0, 0) + ); } - mod get_status_icon { - use super::*; + #[test] + fn counts_single_status() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Pending), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (2, 2, 0, 0, 0) + ); + } - #[test] - fn returns_correct_icon_for_pending() { - assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); - } + #[test] + fn counts_mixed_statuses() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Running), + ); + tasks.insert( + "task3".to_string(), + create_test_task("task3", TaskStatus::Completed), + ); + tasks.insert( + "task4".to_string(), + create_test_task("task4", TaskStatus::Failed), + ); + tasks.insert( + "task5".to_string(), + create_test_task("task5", TaskStatus::Completed), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (5, 1, 1, 2, 1) + ); + } +} - #[test] - fn returns_correct_icon_for_running() { - assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); - } +mod get_status_icon { + use super::*; - #[test] - fn returns_correct_icon_for_completed() { - assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); - } + #[test] + fn returns_correct_icon_for_pending() { + assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); + } - #[test] - fn returns_correct_icon_for_failed() { - assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); - } + #[test] + fn returns_correct_icon_for_running() { + assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); } - mod process_output_lines { - use super::*; + #[test] + fn returns_correct_icon_for_completed() { + assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); + } - #[test] - fn preserves_short_output() { - let output = "line 1\nline 2"; - assert_eq!(process_output_lines(output), "line 1\nline 2"); - } + #[test] + fn returns_correct_icon_for_failed() { + assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); + } +} - #[test] - fn keeps_only_recent_lines_when_too_many() { - let output = "line 1\nline 2\nline 3\nline 4\nline 5"; - let result = process_output_lines(output); - assert_eq!(result, "line 4\nline 5"); - } +mod process_output_lines { + use super::*; - #[test] - fn strips_ansi_codes_from_output() { - let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "red line 1\ngreen line 2"); - } + #[test] + fn preserves_short_output() { + let output = "line 1\nline 2"; + assert_eq!(process_output_lines(output), "line 1\nline 2"); + } - #[test] - fn handles_empty_output() { - assert_eq!(process_output_lines(""), ""); - } + #[test] + fn keeps_only_recent_lines_when_too_many() { + let output = "line 1\nline 2\nline 3\nline 4\nline 5"; + let result = process_output_lines(output); + assert_eq!(result, "line 4\nline 5"); + } - #[test] - fn handles_single_line() { - assert_eq!(process_output_lines("single line"), "single line"); - } + #[test] + fn strips_ansi_codes_from_output() { + let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "red line 1\ngreen line 2"); + } - #[test] - fn combines_ansi_stripping_and_line_limiting() { - let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "line 3\nline 4"); - } + #[test] + fn handles_empty_output() { + assert_eq!(process_output_lines(""), ""); } - mod format_task_timing { - use super::*; - use std::time::Duration; - - fn create_test_task_info_with_timing( - start: Option, - end: Option, - ) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Running, - start_time: start, - end_time: end, - result: None, - current_output: String::new(), - } - } + #[test] + fn handles_single_line() { + assert_eq!(process_output_lines("single line"), "single line"); + } - #[test] - fn returns_none_when_no_start_time() { - let task_info = create_test_task_info_with_timing(None, None); - let current_time = Instant::now(); - assert!(format_task_timing(&task_info, current_time).is_none()); - } + #[test] + fn combines_ansi_stripping_and_line_limiting() { + let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "line 3\nline 4"); + } +} - #[test] - fn formats_running_task_duration() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_millis(1500); - let task_info = create_test_task_info_with_timing(Some(start_time), None); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("1.5s")); - assert!(result.contains("⏱️")); +mod format_task_timing { + use super::*; + use std::time::Duration; + + fn create_test_task_info_with_timing(start: Option, end: Option) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Running, + start_time: start, + end_time: end, + result: None, + current_output: String::new(), } + } - #[test] - fn formats_completed_task_duration() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(2500); - let current_time = Instant::now(); // This shouldn't matter for completed tasks - let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("2.5s")); - assert!(result.contains("⏱️")); - } + #[test] + fn returns_none_when_no_start_time() { + let task_info = create_test_task_info_with_timing(None, None); + let current_time = Instant::now(); + assert!(format_task_timing(&task_info, current_time).is_none()); } - mod format_task_output { - use super::*; - - fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: output.to_string(), - } - } + #[test] + fn formats_running_task_duration() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_millis(1500); + let task_info = create_test_task_info_with_timing(Some(start_time), None); - #[test] - fn returns_none_for_non_running_tasks() { - let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); - assert!(format_task_output(&task_info).is_none()); + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("1.5s")); + assert!(result.contains("⏱️")); + } - let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); - assert!(format_task_output(&task_info).is_none()); + #[test] + fn formats_completed_task_duration() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(2500); + let current_time = Instant::now(); // This shouldn't matter for completed tasks + let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); - let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); - assert!(format_task_output(&task_info).is_none()); - } + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("2.5s")); + assert!(result.contains("⏱️")); + } +} - #[test] - fn returns_none_for_running_task_with_empty_output() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); - assert!(format_task_output(&task_info).is_none()); - } +mod format_task_output { + use super::*; - #[test] - fn formats_running_task_with_output() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, "Building project..."); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("💬")); - assert!(result.contains("Building project...")); + fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: output.to_string(), } + } - #[test] - fn replaces_newlines_with_pipes() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("line 1 | line 2 | line 3")); - } + #[test] + fn returns_none_for_non_running_tasks() { + let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); + assert!(format_task_output(&task_info).is_none()); - #[test] - fn truncates_long_output() { - let long_output = "a".repeat(150); - let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_output.len() + 20); // Account for formatting - } + let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); + assert!(format_task_output(&task_info).is_none()); + + let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); + assert!(format_task_output(&task_info).is_none()); } - mod format_task_error { - use super::*; - use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; + #[test] + fn returns_none_for_running_task_with_empty_output() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); + assert!(format_task_output(&task_info).is_none()); + } - fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { - let result = error_msg.map(|msg| TaskResult { - task_id: "test_task".to_string(), - status: TaskStatus::Failed, - data: None, - error: Some(msg.to_string()), - }); - - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Failed, - start_time: None, - end_time: None, - result, - current_output: String::new(), - } - } + #[test] + fn formats_running_task_with_output() { + let task_info = + create_test_task_info_with_output(TaskStatus::Running, "Building project..."); + let result = format_task_output(&task_info).unwrap(); - #[test] - fn returns_none_when_no_error() { - let task_info = create_test_task_info_with_error(None); - assert!(format_task_error(&task_info).is_none()); - } + assert!(result.contains("💬")); + assert!(result.contains("Building project...")); + } - #[test] - fn formats_error_message() { - let task_info = create_test_task_info_with_error(Some("File not found")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("⚠️")); - assert!(result.contains("File not found")); - } + #[test] + fn replaces_newlines_with_pipes() { + let task_info = + create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); + let result = format_task_output(&task_info).unwrap(); - #[test] - fn replaces_newlines_in_error() { - let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("Error on line 1 Error on line 2")); - } + assert!(result.contains("line 1 | line 2 | line 3")); + } - #[test] - fn truncates_long_error() { - let long_error = "error ".repeat(30); - let task_info = create_test_task_info_with_error(Some(&long_error)); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_error.len() + 20); // Account for formatting + #[test] + fn truncates_long_output() { + let long_output = "a".repeat(150); + let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_output.len() + 20); // Account for formatting + } +} + +mod format_task_error { + use super::*; + use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; + + fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { + let result = error_msg.map(|msg| TaskResult { + task_id: "test_task".to_string(), + status: TaskStatus::Failed, + data: None, + error: Some(msg.to_string()), + }); + + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Failed, + start_time: None, + end_time: None, + result, + current_output: String::new(), } } - mod format_task_display { - use super::*; - use std::time::Duration; + #[test] + fn returns_none_when_no_error() { + let task_info = create_test_task_info_with_error(None); + assert!(format_task_error(&task_info).is_none()); + } + + #[test] + fn formats_error_message() { + let task_info = create_test_task_info_with_error(Some("File not found")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("⚠️")); + assert!(result.contains("File not found")); + } + + #[test] + fn replaces_newlines_in_error() { + let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("Error on line 1 Error on line 2")); + } - fn create_comprehensive_task_info( - task_name: &str, - status: TaskStatus, - start_time: Option, - end_time: Option, - current_output: &str, - error: Option<&str>, - ) -> TaskInfo { - let result = error.map(|msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { + #[test] + fn truncates_long_error() { + let long_error = "error ".repeat(30); + let task_info = create_test_task_info_with_error(Some(&long_error)); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_error.len() + 20); // Account for formatting + } +} + +mod format_task_display { + use super::*; + use std::time::Duration; + + fn create_comprehensive_task_info( + task_name: &str, + status: TaskStatus, + start_time: Option, + end_time: Option, + current_output: &str, + error: Option<&str>, + ) -> TaskInfo { + let result = error.map( + |msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { task_id: task_name.to_string(), status: status.clone(), data: None, error: Some(msg.to_string()), - }); - - TaskInfo { - task: Task { - id: task_name.to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time, - end_time, - result, - current_output: current_output.to_string(), - } + }, + ); + + TaskInfo { + task: Task { + id: task_name.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time, + end_time, + result, + current_output: current_output.to_string(), } + } - #[test] - fn formats_pending_task() { - let task_info = create_comprehensive_task_info( - "pending_task", - TaskStatus::Pending, - None, - None, - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("⏳")); - assert!(result.contains("pending_task")); - assert!(result.contains("(test)")); - } + #[test] + fn formats_pending_task() { + let task_info = create_comprehensive_task_info( + "pending_task", + TaskStatus::Pending, + None, + None, + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("⏳")); + assert!(result.contains("pending_task")); + assert!(result.contains("(test)")); + } - #[test] - fn formats_running_task_with_output() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_secs(2); - let task_info = create_comprehensive_task_info( - "running_task", - TaskStatus::Running, - Some(start_time), - None, - "Compiling...", - None, - ); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("🏃")); - assert!(result.contains("running_task")); - assert!(result.contains("2.0s")); - assert!(result.contains("💬")); - assert!(result.contains("Compiling...")); - } + #[test] + fn formats_running_task_with_output() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_secs(2); + let task_info = create_comprehensive_task_info( + "running_task", + TaskStatus::Running, + Some(start_time), + None, + "Compiling...", + None, + ); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("🏃")); + assert!(result.contains("running_task")); + assert!(result.contains("2.0s")); + assert!(result.contains("💬")); + assert!(result.contains("Compiling...")); + } - #[test] - fn formats_failed_task_with_error() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(1500); - let task_info = create_comprehensive_task_info( - "failed_task", - TaskStatus::Failed, - Some(start_time), - Some(end_time), - "", - Some("Compilation failed"), - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("❌")); - assert!(result.contains("failed_task")); - assert!(result.contains("1.5s")); - assert!(result.contains("⚠️")); - assert!(result.contains("Compilation failed")); - } + #[test] + fn formats_failed_task_with_error() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(1500); + let task_info = create_comprehensive_task_info( + "failed_task", + TaskStatus::Failed, + Some(start_time), + Some(end_time), + "", + Some("Compilation failed"), + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("❌")); + assert!(result.contains("failed_task")); + assert!(result.contains("1.5s")); + assert!(result.contains("⚠️")); + assert!(result.contains("Compilation failed")); + } - #[test] - fn formats_completed_task() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_secs(3); - let task_info = create_comprehensive_task_info( - "completed_task", - TaskStatus::Completed, - Some(start_time), - Some(end_time), - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("✅")); - assert!(result.contains("completed_task")); - assert!(result.contains("3.0s")); - } + #[test] + fn formats_completed_task() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_secs(3); + let task_info = create_comprehensive_task_info( + "completed_task", + TaskStatus::Completed, + Some(start_time), + Some(end_time), + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("✅")); + assert!(result.contains("completed_task")); + assert!(result.contains("3.0s")); } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index d9f750480432..73ea8b5ab372 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -5,6 +5,10 @@ use std::sync::Arc; #[cfg(test)] mod tests { + use std::sync::atomic::AtomicUsize; + + use tokio::sync::mpsc; + use super::*; use crate::agents::sub_recipe_execution_tool::types::Task; @@ -18,8 +22,6 @@ mod tests { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), dashboard: None, }); @@ -30,7 +32,6 @@ mod tests { assert!(!handle.is_finished()); // Signal stop and close the channel to let the worker exit - shared_state.should_stop.store(true, Ordering::SeqCst); drop(task_tx); // Close the channel // Wait for the worker to finish From f7832da99126d25f8f9c2dee0bc4545aaf10059d Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 21:35:10 +1000 Subject: [PATCH 09/43] display param values in the task list --- .../agents/sub_recipe_execution_tool/tasks.rs | 34 ++++++------ .../agents/sub_recipe_execution_tool/types.rs | 40 +++++++++++++- .../sub_recipe_execution_tool/utils/mod.rs | 54 ++++++++++++------- 3 files changed, 93 insertions(+), 35 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 9543f0cec504..ceac281937e2 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -58,29 +58,31 @@ async fn execute_task(task: Task, dashboard: Option>) -> Resu fn build_command(task: &Task) -> Result<(Command, String), String> { let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { - let sub_recipe = task.payload.get("sub_recipe").unwrap(); - let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap(); - let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap(); - let command_parameters = sub_recipe.get("command_parameters").unwrap(); + let sub_recipe_name = task + .get_sub_recipe_name() + .ok_or("Missing sub_recipe name")?; + let path = task + .get_sub_recipe_path() + .ok_or("Missing sub_recipe path")?; + let command_parameters = task + .get_command_parameters() + .ok_or("Missing command_parameters")?; + output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); cmd.arg("run").arg("--recipe").arg(path); - if let Some(params_map) = command_parameters.as_object() { - for (key, value) in params_map { - let key_str = key.to_string(); - let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); - cmd.arg("--params") - .arg(format!("{}={}", key_str, value_str)); - } + + for (key, value) in command_parameters { + let key_str = key.to_string(); + let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); + cmd.arg("--params") + .arg(format!("{}={}", key_str, value_str)); } cmd } else { let text = task - .payload - .get("text_instruction") - .unwrap() - .as_str() - .unwrap(); + .get_text_instruction() + .ok_or("Missing text_instruction")?; let mut cmd = Command::new("goose"); cmd.arg("run").arg("--text").arg(text); cmd diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 2f40490a564c..ed53c7ea08e3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Map, Value}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; @@ -13,6 +13,44 @@ pub struct Task { pub payload: Value, } +impl Task { + pub fn get_sub_recipe(&self) -> Option<&Map> { + if self.task_type == "sub_recipe" { + self.payload.get("sub_recipe").and_then(|sr| sr.as_object()) + } else { + None + } + } + + pub fn get_command_parameters(&self) -> Option<&Map> { + self.get_sub_recipe() + .and_then(|sr| sr.get("command_parameters")) + .and_then(|cp| cp.as_object()) + } + + pub fn get_sub_recipe_name(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("name")) + .and_then(|name| name.as_str()) + } + + pub fn get_sub_recipe_path(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("recipe_path")) + .and_then(|path| path.as_str()) + } + + pub fn get_text_instruction(&self) -> Option<&str> { + if self.task_type != "sub_recipe" { + self.payload + .get("text_instruction") + .and_then(|text| text.as_str()) + } else { + None + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TaskResult { pub task_id: String, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index 31f20273ec9e..114febb71ff3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,3 +1,4 @@ +use serde_json::{Map, Value}; use std::collections::HashMap; use tokio::time::Instant; @@ -10,17 +11,14 @@ const ERROR_PREVIEW_LENGTH: usize = 80; const CLEAR_TO_EOL: &str = "\x1b[K"; pub fn get_task_name(task_info: &TaskInfo) -> &str { - if task_info.task.task_type == "sub_recipe" { - task_info - .task - .payload - .get("sub_recipe") - .and_then(|sr| sr.get("name")) - .and_then(|n| n.as_str()) - .unwrap_or(&task_info.task.id) - } else { - &task_info.task.id - } + task_info + .task + .get_sub_recipe_name() + .unwrap_or(&task_info.task.id) +} + +pub fn get_command_parameters(task_info: &TaskInfo) -> Option<&Map> { + task_info.task.get_command_parameters() } pub fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { @@ -68,9 +66,7 @@ pub fn strip_ansi_codes(text: &str) -> String { result } -// Pure utility functions for dashboard rendering -/// Get status icon for a given task status pub fn get_status_icon(status: &TaskStatus) -> &'static str { match status { TaskStatus::Pending => "⏳", @@ -80,7 +76,6 @@ pub fn get_status_icon(status: &TaskStatus) -> &'static str { } } -/// Process output lines, keeping only recent lines and stripping ANSI codes pub fn process_output_lines(output: &str) -> String { let lines: Vec<&str> = output.lines().collect(); let recent_lines = if lines.len() > MAX_OUTPUT_LINES { @@ -139,7 +134,29 @@ pub fn format_task_error(task_info: &TaskInfo) -> Option { }) } -/// Format complete task display +pub fn format_command_parameters(task_info: &TaskInfo) -> Option { + get_command_parameters(task_info).map(|params| { + if params.is_empty() { + return format!(" 📋 Parameters: (none){}\n", CLEAR_TO_EOL); + } + + let params_str = params + .iter() + .map(|(key, value)| { + let value_str = match value { + Value::String(s) => s.clone(), + _ => value.to_string(), + }; + format!("{}={}", key, value_str) + }) + .collect::>() + .join(", "); + + let params_preview = truncate_with_ellipsis(¶ms_str, OUTPUT_PREVIEW_LENGTH); + format!(" 📋 Parameters: {}{}\n", params_preview, CLEAR_TO_EOL) + }) +} + pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> String { let mut display = String::new(); @@ -153,12 +170,14 @@ pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> Strin status_icon, task_name, task_info.task.task_type, CLEAR_TO_EOL )); - // Task timing + if let Some(params) = format_command_parameters(task_info) { + display.push_str(¶ms); + } + if let Some(timing) = format_task_timing(task_info, current_time) { display.push_str(&timing); } - // Task output (if running) if let Some(output) = format_task_output(task_info) { display.push_str(&output); } @@ -168,7 +187,6 @@ pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> Strin display.push_str(&error); } - // Empty line display.push_str(&format!( "{} ", From e601ca2a28dc3201964e90bfbb5e475e5857d3eb Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 02:06:50 +1000 Subject: [PATCH 10/43] use stream to send the output to session --- crates/goose-cli/src/session/mod.rs | 21 ++- .../sub_recipe_execution_tool/dashboard.rs | 146 ++++++++++++----- .../sub_recipe_execution_tool/executor.rs | 155 ++++++++---------- .../agents/sub_recipe_execution_tool/lib.rs | 18 +- .../sub_recipe_execute_task_tool.rs | 40 +++-- .../agents/sub_recipe_execution_tool/tasks.rs | 24 ++- .../agents/sub_recipe_execution_tool/types.rs | 2 +- .../sub_recipe_execution_tool/utils/mod.rs | 1 - .../sub_recipe_execution_tool/workers.rs | 61 +------ crates/goose/src/agents/sub_recipe_manager.rs | 25 --- 10 files changed, 253 insertions(+), 240 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index e978fbe12beb..6efcf16187ee 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -15,6 +15,7 @@ use goose::permission::Permission; use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; +use std::io::Write; use anyhow::{Context, Result}; use completion::GooseCompleter; @@ -1011,8 +1012,11 @@ impl Session { }; (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) } else if let Some(Value::String(output)) = o.get("output") { - // Fallback for other MCP notification types + // Shell tool notification (output.to_owned(), None, None) + } else if let Some(Value::String(display)) = o.get("display") { + // Dashboard notification - return raw display content with ANSI codes + (display.to_owned(), None, Some("dashboard".to_string())) } else { (data.to_string(), None, None) } @@ -1022,6 +1026,21 @@ impl Session { }, }; + // Handle dashboard notifications specially - print raw content with ANSI codes + if let Some(ref notification_type) = _notification_type { + if notification_type == "dashboard" { + if interactive { + let _ = progress_bars.hide(); + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } else { + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } + continue; // Skip the normal notification handling below + } + } + // Handle subagent notifications - show immediately if let Some(_id) = subagent_id { // Show subagent notifications immediately (no buffering) with compact spacing diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs index ee8ae16bd4b5..572c82446073 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -1,14 +1,21 @@ +use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; +use serde_json::json; use std::collections::HashMap; -use std::io::{self, Write}; use std::sync::Arc; -use tokio::sync::RwLock; +use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_display, get_task_name, process_output_lines, + count_by_status, format_task_display, get_task_name, }; +#[derive(Debug, Clone, PartialEq)] +pub enum DisplayMode { + Dashboard, + SingleTaskOutput, +} + const THROTTLE_INTERVAL_MS: u64 = 1000; const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; @@ -20,10 +27,16 @@ pub struct TaskDashboard { last_display: Arc>, last_refresh: Arc>, initial_display_shown: Arc>, + notifier: mpsc::Sender, + display_mode: DisplayMode, } impl TaskDashboard { - pub fn new(tasks: Vec) -> Self { + pub fn new( + tasks: Vec, + display_mode: DisplayMode, + notifier: mpsc::Sender, + ) -> Self { let task_map = tasks .into_iter() .map(|task| { @@ -47,6 +60,8 @@ impl TaskDashboard { last_display: Arc::new(RwLock::new(String::new())), last_refresh: Arc::new(RwLock::new(Instant::now())), initial_display_shown: Arc::new(RwLock::new(false)), + notifier, + display_mode, } } @@ -71,15 +86,34 @@ impl TaskDashboard { self.refresh_display().await; } - pub async fn update_task_output(&self, task_id: &str, output: &str) { - let mut tasks = self.tasks.write().await; - if let Some(task_info) = tasks.get_mut(task_id) { - task_info.current_output = process_output_lines(output); - } - drop(tasks); + pub async fn send_live_output(&self, task_id: &str, line: &str) { + match self.display_mode { + DisplayMode::SingleTaskOutput => { + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "dashboard", + "display": format!("{}\n", line) + } + })), + })); + } + DisplayMode::Dashboard => { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.current_output.push_str(line); + task_info.current_output.push('\n'); + } + drop(tasks); - if !self.should_throttle_refresh().await { - self.refresh_display().await; + if !self.should_throttle_refresh().await { + self.refresh_display().await; + } + } } } @@ -123,64 +157,98 @@ impl TaskDashboard { async fn update_display_if_changed(&self, display: String) { let mut last_display = self.last_display.write().await; if *last_display != display { - print!("{}", display); - io::stdout().flush().unwrap(); + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "dashboard", + "display": display.clone() + } + })), + })); *last_display = display; } } pub async fn refresh_display(&self) { - let tasks = self.tasks.read().await; - let mut display = String::new(); + match self.display_mode { + DisplayMode::Dashboard => { + let tasks = self.tasks.read().await; + let mut display = String::new(); - let mut initial_shown = self.initial_display_shown.write().await; - self.render_header(&mut display, &mut initial_shown); - drop(initial_shown); + let mut initial_shown = self.initial_display_shown.write().await; + self.render_header(&mut display, &mut initial_shown); + drop(initial_shown); - self.render_progress_line(&mut display, &tasks); + self.render_progress_line(&mut display, &tasks); - let mut task_list: Vec<_> = tasks.values().collect(); - task_list.sort_by_key(|t| &t.task.id); + let mut task_list: Vec<_> = tasks.values().collect(); + task_list.sort_by_key(|t| &t.task.id); - for task_info in task_list { - self.render_task(&mut display, task_info); - } + for task_info in task_list { + self.render_task(&mut display, task_info); + } - display.push_str(CLEAR_BELOW); + display.push_str(CLEAR_BELOW); - self.update_display_if_changed(display).await; + self.update_display_if_changed(display).await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + // Live output is handled via send_live_output method + } + } } pub async fn show_final_summary(&self) { let tasks = self.tasks.read().await; - println!("Execution Complete!"); - println!("═══════════════════════"); + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); let (total, _, _, completed, failed) = count_by_status(&tasks); - println!("Total Tasks: {}", total); - println!("✅ Completed: {}", completed); - println!("❌ Failed: {}", failed); - println!( - "📈 Success Rate: {:.1}%", + summary.push_str(&format!("Total Tasks: {}\n", total)); + summary.push_str(&format!("✅ Completed: {}\n", completed)); + summary.push_str(&format!("❌ Failed: {}\n", failed)); + summary.push_str(&format!( + "📈 Success Rate: {:.1}%\n", (completed as f64 / total as f64) * 100.0 - ); + )); if failed > 0 { - println!("\n❌ Failed Tasks:"); + summary.push_str("\n❌ Failed Tasks:\n"); for task_info in tasks.values() { if matches!(task_info.status, TaskStatus::Failed) { let task_name = get_task_name(task_info); - println!(" • {}", task_name); + summary.push_str(&format!(" • {}\n", task_name)); if let Some(error) = task_info.error() { - println!(" Error: {}", error); + summary.push_str(&format!(" Error: {}\n", error)); } } } } - println!("\n📝 Generating summary..."); + summary.push_str("\n📝 Generating summary...\n"); + + // Send the final summary via notification + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "dashboard", + "display": summary + } + })), + })); + sleep(Duration::from_millis(500)).await; } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index 9d869666e4af..3e1974323255 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -1,9 +1,10 @@ +use mcp_core::protocol::JsonRpcMessage; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::dashboard::{DisplayMode, TaskDashboard}; use crate::agents::sub_recipe_execution_tool::lib::{ Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -12,10 +13,18 @@ use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; const EXECUTION_STATUS_COMPLETED: &str = "completed"; -pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { +pub async fn execute_single_task( + task: &Task, + config: Config, + notifier: mpsc::Sender, +) -> ExecutionResponse { let start_time = Instant::now(); - let result = process_task(task, config.timeout_seconds).await; - + let dashboard = Arc::new(TaskDashboard::new( + vec![task.clone()], + DisplayMode::SingleTaskOutput, + notifier, + )); + let result = process_task(task, config.timeout_seconds, dashboard).await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -26,6 +35,62 @@ pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionRespon } } +pub async fn execute_tasks_in_parallel( + tasks: Vec, + config: Config, + notifier: mpsc::Sender, +) -> ExecutionResponse { + let dashboard = Arc::new(TaskDashboard::new( + tasks.clone(), + DisplayMode::Dashboard, + notifier, + )); + let start_time = Instant::now(); + let task_count = tasks.len(); + + if task_count == 0 { + return create_empty_response(); + } + + dashboard.refresh_display().await; + + let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); + + if let Err(e) = send_tasks_to_channel(tasks, task_tx).await { + eprintln!("Execution failed: {}", e); + return create_error_response(e); + } + + let shared_state = create_shared_state(task_rx, result_tx, dashboard.clone()); + + // Simple static worker allocation - no dynamic scaling needed + let worker_count = std::cmp::min(task_count, config.max_workers); + let mut worker_handles = Vec::new(); + for i in 0..worker_count { + let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); + worker_handles.push(handle); + } + + let results = collect_results(&mut result_rx, dashboard.clone(), task_count).await; + + for handle in worker_handles { + if let Err(e) = handle.await { + eprintln!("Worker error: {}", e); + } + } + + dashboard.show_final_summary().await; + + let execution_time = start_time.elapsed().as_millis(); + let stats = calculate_stats(&results, execution_time); + + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results, + stats, + } +} + fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats { let completed = results .iter() @@ -44,30 +109,6 @@ fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> Execution } } -struct ExecutionContext { - tasks: Vec, - config: Config, - dashboard: Arc, - start_time: Instant, -} - -impl ExecutionContext { - fn new(tasks: Vec, config: Config) -> Self { - let dashboard = Arc::new(TaskDashboard::new(tasks.clone())); - - Self { - tasks, - config, - dashboard, - start_time: Instant::now(), - } - } - - fn task_count(&self) -> usize { - self.tasks.len() - } -} - fn create_channels( task_count: usize, ) -> ( @@ -90,7 +131,7 @@ fn create_shared_state( task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - dashboard: Some(dashboard), + dashboard, }) } @@ -138,50 +179,6 @@ async fn collect_results( results } -async fn execute_with_context(ctx: ExecutionContext) -> Result { - let task_count = ctx.task_count(); - - if task_count == 0 { - return Ok(create_empty_response()); - } - - ctx.dashboard.refresh_display().await; - - let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); - - send_tasks_to_channel(ctx.tasks, task_tx).await?; - - let shared_state = create_shared_state(task_rx, result_tx, ctx.dashboard.clone()); - - // Simple static worker allocation - no dynamic scaling needed - let worker_count = std::cmp::min(task_count, ctx.config.max_workers); - let mut worker_handles = Vec::new(); - for i in 0..worker_count { - let handle = spawn_worker(shared_state.clone(), i, ctx.config.timeout_seconds); - worker_handles.push(handle); - } - - let results = collect_results(&mut result_rx, ctx.dashboard.clone(), task_count).await; - - // Wait for all workers to finish - for handle in worker_handles { - if let Err(e) = handle.await { - eprintln!("Worker error: {}", e); - } - } - - ctx.dashboard.show_final_summary().await; - - let execution_time = ctx.start_time.elapsed().as_millis(); - let stats = calculate_stats(&results, execution_time); - - Ok(ExecutionResponse { - status: EXECUTION_STATUS_COMPLETED.to_string(), - results, - stats, - }) -} - fn create_error_response(_error: String) -> ExecutionResponse { ExecutionResponse { status: "failed".to_string(), @@ -194,15 +191,3 @@ fn create_error_response(_error: String) -> ExecutionResponse { }, } } - -pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { - let ctx = ExecutionContext::new(tasks, config); - - match execute_with_context(ctx).await { - Ok(response) => response, - Err(e) => { - eprintln!("Execution failed: {}", e); - create_error_response(e) - } - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 6c4718c2b3a0..5973025a1786 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -1,12 +1,19 @@ -use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; -pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; +use crate::agents::sub_recipe_execution_tool::executor::{ + execute_single_task, execute_tasks_in_parallel, +}; pub use crate::agents::sub_recipe_execution_tool::types::{ Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; +use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; +use tokio::sync::mpsc; -pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { +pub async fn execute_tasks( + input: Value, + execution_mode: &str, + notifier: mpsc::Sender, +) -> Result { let tasks: Vec = serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) .map_err(|e| format!("Failed to parse tasks: {}", e))?; @@ -17,11 +24,12 @@ pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { if task_count == 1 { - let response = execute_single_task(&tasks[0], config).await; + let response = execute_single_task(&tasks[0], config, notifier).await; serde_json::to_value(response) .map_err(|e| format!("Failed to serialize response: {}", e)) } else { @@ -29,7 +37,7 @@ pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { - let response = parallel_execute(tasks, config).await; + let response = execute_tasks_in_parallel(tasks, config, notifier).await; serde_json::to_value(response) .map_err(|e| format!("Failed to serialize response: {}", e)) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 0e7e061fd4d1..870350a55da1 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -4,6 +4,9 @@ use serde_json::Value; use crate::agents::{ sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, }; +use mcp_core::protocol::JsonRpcMessage; +use tokio::sync::mpsc; +use tokio_stream; pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; pub fn create_sub_recipe_execute_task_tool() -> Tool { @@ -119,18 +122,31 @@ Pre-created Task Based: } pub async fn run_tasks(execute_data: Value) -> ToolCallResult { - let execute_data_clone = execute_data.clone(); - let default_execution_mode_value = Value::String("sequential".to_string()); - let execution_mode = execute_data_clone - .get("execution_mode") - .unwrap_or(&default_execution_mode_value) - .as_str() - .unwrap_or("sequential"); - match execute_tasks(execute_data, execution_mode).await { - Ok(result) => { - let output = serde_json::to_string(&result).unwrap(); - ToolCallResult::from(Ok(vec![Content::text(output)])) + let (notification_tx, notification_rx) = mpsc::channel::(100); + + let result_future = async move { + let execute_data_clone = execute_data.clone(); + let default_execution_mode_value = Value::String("sequential".to_string()); + let execution_mode = execute_data_clone + .get("execution_mode") + .unwrap_or(&default_execution_mode_value) + .as_str() + .unwrap_or("sequential"); + + match execute_tasks(execute_data, execution_mode, notification_tx).await { + Ok(result) => { + let output = serde_json::to_string(&result).unwrap(); + Ok(vec![Content::text(output)]) + } + Err(e) => Err(ToolError::ExecutionError(e.to_string())), } - Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), + }; + + // Convert receiver to stream + let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx); + + ToolCallResult { + result: Box::new(Box::pin(result_future)), + notification_stream: Some(Box::new(notification_stream)), } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index ceac281937e2..969d885cedea 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -9,19 +9,15 @@ use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; -pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { - process_task_with_dashboard(task, timeout_seconds, None).await -} - -pub async fn process_task_with_dashboard( +pub async fn process_task( task: &Task, timeout_seconds: u64, - dashboard: Option>, + dashboard: Arc, ) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - match timeout(timeout_duration, execute_task(task_clone, dashboard)).await { + match timeout(timeout_duration, get_task_result(task_clone, dashboard)).await { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, @@ -43,7 +39,7 @@ pub async fn process_task_with_dashboard( } } -async fn execute_task(task: Task, dashboard: Option>) -> Result { +async fn get_task_result(task: Task, dashboard: Arc) -> Result { let (command, output_identifier) = build_command(&task)?; let (stdout_output, stderr_output, success) = run_command(command, &output_identifier, &task.id, dashboard).await?; @@ -97,7 +93,7 @@ async fn run_command( mut command: Command, output_identifier: &str, task_id: &str, - dashboard: Option>, + dashboard: Arc, ) -> Result<(String, String, bool), String> { let mut child = command .spawn() @@ -108,7 +104,8 @@ async fn run_command( let stdout_task = spawn_output_reader(stdout, output_identifier, false, task_id, dashboard.clone()); - let stderr_task = spawn_output_reader(stderr, output_identifier, true, task_id, None); + let stderr_task = + spawn_output_reader(stderr, output_identifier, true, task_id, dashboard.clone()); let status = child .wait() @@ -126,7 +123,7 @@ fn spawn_output_reader( output_identifier: &str, is_stderr: bool, task_id: &str, - dashboard: Option>, + dashboard: Arc, ) -> tokio::task::JoinHandle { let output_identifier = output_identifier.to_string(); let task_id = task_id.to_string(); @@ -138,9 +135,8 @@ fn spawn_output_reader( buffer.push('\n'); if !is_stderr { - if let Some(dashboard) = &dashboard { - dashboard.update_task_output(&task_id, &buffer).await; - } + // Use dashboard's smart output handling based on display mode + dashboard.send_live_output(&task_id, &line).await; } else { eprintln!("[stderr for {}] {}", output_identifier, line); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index ed53c7ea08e3..f93c4f5c0c9a 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -93,7 +93,7 @@ pub struct SharedState { pub task_receiver: Arc>>, pub result_sender: mpsc::Sender, pub active_workers: Arc, - pub dashboard: Option>, + pub dashboard: Arc, } impl SharedState { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index 114febb71ff3..7b684126c383 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -66,7 +66,6 @@ pub fn strip_ansi_codes(text: &str) -> String { result } - pub fn get_status_icon(status: &TaskStatus) -> &'static str { match status { TaskStatus::Pending => "⏳", diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 73ea8b5ab372..ea11a7e604cb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,66 +1,12 @@ -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; -use crate::agents::sub_recipe_execution_tool::tasks::{process_task, process_task_with_dashboard}; -use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task, TaskResult}; +use crate::agents::sub_recipe_execution_tool::tasks::process_task; +use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task}; use std::sync::Arc; -#[cfg(test)] -mod tests { - use std::sync::atomic::AtomicUsize; - - use tokio::sync::mpsc; - - use super::*; - use crate::agents::sub_recipe_execution_tool::types::Task; - - #[tokio::test] - async fn test_spawn_worker_returns_handle() { - // Create a simple shared state for testing - let (task_tx, task_rx) = mpsc::channel::(1); - let (result_tx, _result_rx) = mpsc::channel::(1); - - let shared_state = Arc::new(SharedState { - task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), - result_sender: result_tx, - active_workers: Arc::new(AtomicUsize::new(0)), - dashboard: None, - }); - - // Test that spawn_worker returns a JoinHandle - let handle = spawn_worker(shared_state.clone(), 0, 5); - - // Verify it's a JoinHandle by checking we can abort it - assert!(!handle.is_finished()); - - // Signal stop and close the channel to let the worker exit - drop(task_tx); // Close the channel - - // Wait for the worker to finish - let result = handle.await; - assert!(result.is_ok()); - } -} - async fn receive_task(state: &SharedState) -> Option { let mut receiver = state.task_receiver.lock().await; receiver.recv().await } -async fn execute_task( - task: Task, - timeout: u64, - dashboard: Option>, -) -> TaskResult { - if let Some(dashboard) = &dashboard { - dashboard.start_task(&task.id).await; - } - - if let Some(dashboard) = dashboard { - process_task_with_dashboard(&task, timeout, Some(dashboard)).await - } else { - process_task(&task, timeout).await - } -} - pub fn spawn_worker( state: Arc, worker_id: usize, @@ -75,7 +21,8 @@ pub fn spawn_worker( async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { while let Some(task) = receive_task(&state).await { - let result = execute_task(task, timeout_seconds, state.dashboard.clone()).await; + state.dashboard.start_task(&task.id).await; + let result = process_task(&task, timeout_seconds, state.dashboard.clone()).await; if let Err(e) = state.result_sender.send(result).await { eprintln!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index d759914a9d9e..cb01c3ffe4dc 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -61,31 +61,6 @@ impl SubRecipeManager { } } - // async fn call_sub_recipe_tool( - // &self, - // tool_name: &str, - // params: Value, - // ) -> Result, ToolError> { - // let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { - // let sub_recipe_name = tool_name - // .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX) - // .and_then(|s| s.strip_prefix("_")) - // .ok_or_else(|| { - // ToolError::InvalidParameters(format!( - // "Invalid sub-recipe tool name format: {}", - // tool_name - // )) - // }) - // .unwrap(); - - // ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) - // })?; - - // let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| { - // ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) - // })?; - // Ok(vec![Content::text(output)]) - // } async fn call_sub_recipe_tool( &self, tool_name: &str, From 4d11f373fd2bc889abf776d25a8655573b07687b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 03:30:00 +1000 Subject: [PATCH 11/43] refactored the files --- crates/goose-cli/src/session/mod.rs | 39 ++-- .../src/session/task_execution_display.rs | 174 +++++++++++++++++ .../sub_recipe_execution_tool/executor.rs | 28 +-- .../agents/sub_recipe_execution_tool/mod.rs | 2 +- ...dashboard.rs => task_execution_tracker.rs} | 183 +++++++----------- .../agents/sub_recipe_execution_tool/tasks.rs | 51 +++-- .../agents/sub_recipe_execution_tool/types.rs | 4 +- .../sub_recipe_execution_tool/workers.rs | 5 +- 8 files changed, 325 insertions(+), 161 deletions(-) create mode 100644 crates/goose-cli/src/session/task_execution_display.rs rename crates/goose/src/agents/sub_recipe_execution_tool/{dashboard.rs => task_execution_tracker.rs} (50%) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 6efcf16187ee..920267351331 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -4,8 +4,11 @@ mod export; mod input; mod output; mod prompt; +mod task_execution_display; mod thinking; +use crate::session::task_execution_display::TASK_EXECUTION_NOTIFICATION_TYPE; + pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; @@ -16,6 +19,7 @@ use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; use std::io::Write; +use task_execution_display::format_task_execution_notification; use anyhow::{Context, Result}; use completion::GooseCompleter; @@ -1011,12 +1015,8 @@ impl Session { } }; (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) - } else if let Some(Value::String(output)) = o.get("output") { - // Shell tool notification - (output.to_owned(), None, None) - } else if let Some(Value::String(display)) = o.get("display") { - // Dashboard notification - return raw display content with ANSI codes - (display.to_owned(), None, Some("dashboard".to_string())) + } else if let Some(result) = format_task_execution_notification(data) { + result } else { (data.to_string(), None, None) } @@ -1026,9 +1026,17 @@ impl Session { }, }; - // Handle dashboard notifications specially - print raw content with ANSI codes - if let Some(ref notification_type) = _notification_type { - if notification_type == "dashboard" { + // Handle subagent notifications - show immediately + if let Some(_id) = subagent_id { + // Show subagent notifications immediately (no buffering) with compact spacing + if interactive { + let _ = progress_bars.hide(); + println!("{}", console::style(&formatted_message).green().dim()); + } else { + progress_bars.log(&formatted_message); + } + } else if let Some(ref notification_type) = _notification_type { + if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { if interactive { let _ = progress_bars.hide(); print!("{}", formatted_message); @@ -1037,20 +1045,9 @@ impl Session { print!("{}", formatted_message); std::io::stdout().flush().unwrap(); } - continue; // Skip the normal notification handling below } } - - // Handle subagent notifications - show immediately - if let Some(_id) = subagent_id { - // Show subagent notifications immediately (no buffering) with compact spacing - if interactive { - let _ = progress_bars.hide(); - println!("{}", console::style(&formatted_message).green().dim()); - } else { - progress_bars.log(&formatted_message); - } - } else { + else { // Non-subagent notification, display immediately with compact spacing if interactive { let _ = progress_bars.hide(); diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs new file mode 100644 index 000000000000..ba57c00dedd6 --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -0,0 +1,174 @@ +use serde_json::Value; + +const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; +const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; +const CLEAR_TO_EOL: &str = "\x1b[K"; +const CLEAR_BELOW: &str = "\x1b[J"; +pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; + +pub fn format_tasks_update(data: &Value) -> String { + let mut display = String::new(); + + // Determine if this is initial display or update + static mut INITIAL_SHOWN: bool = false; + unsafe { + if !INITIAL_SHOWN { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + INITIAL_SHOWN = true; + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); + } + } + + if let Some(stats) = data.get("stats") { + let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); + let pending = stats.get("pending").and_then(|v| v.as_u64()).unwrap_or(0); + let running = stats.get("running").and_then(|v| v.as_u64()).unwrap_or(0); + let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); + let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); + + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + total, pending, running, completed, failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); + } + + if let Some(tasks) = data.get("tasks").and_then(|t| t.as_array()) { + for task in tasks { + let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let status = task + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let task_type = task + .get("task_type") + .and_then(|v| v.as_str()) + .unwrap_or("task"); + + let status_icon = match status { + "Pending" => "⏳", + "Running" => "🏃", + "Completed" => "✅", + "Failed" => "❌", + _ => "◯", + }; + + display.push_str(&format!( + "{} {} ({}): {}\n", + status_icon, id, task_type, status + )); + + if status == "Running" { + if let Some(output) = task.get("current_output").and_then(|v| v.as_str()) { + if !output.trim().is_empty() { + let lines: Vec<&str> = output.lines().collect(); + if lines.len() > 3 { + display.push_str(" ...\n"); + for line in lines.iter().rev().take(3).rev() { + display.push_str(&format!(" {}\n", line)); + } + } else { + for line in lines { + display.push_str(&format!(" {}\n", line)); + } + } + } + } + } + + if status == "Failed" { + if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + display.push_str(&format!(" Error: {}\n", error)); + } + } + } + } + + display.push_str(CLEAR_BELOW); + display +} + +pub fn format_tasks_complete(data: &Value) -> String { + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); + + if let Some(stats) = data.get("stats") { + let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); + let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); + let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); + let success_rate = stats + .get("success_rate") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + + summary.push_str(&format!("Total Tasks: {}\n", total)); + summary.push_str(&format!("✅ Completed: {}\n", completed)); + summary.push_str(&format!("❌ Failed: {}\n", failed)); + summary.push_str(&format!("📈 Success Rate: {:.1}%\n", success_rate)); + } + + if let Some(failed_tasks) = data.get("failed_tasks").and_then(|t| t.as_array()) { + if !failed_tasks.is_empty() { + summary.push_str("\n❌ Failed Tasks:\n"); + for task in failed_tasks { + let name = task + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown"); + summary.push_str(&format!(" • {}\n", name)); + if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + summary.push_str(&format!(" Error: {}\n", error)); + } + } + } + } + + summary.push_str("\n📝 Generating summary...\n"); + summary +} + +pub fn format_task_execution_notification( + data: &Value, +) -> Option<(String, Option, Option)> { + if let Value::Object(o) = data { + if o.get("type").and_then(|t| t.as_str()) == Some(TASK_EXECUTION_NOTIFICATION_TYPE) { + return Some(match o.get("subtype").and_then(|t| t.as_str()) { + Some("line_output") => { + if let Some(Value::String(line_output)) = o.get("output") { + ( + format!("{}\n", line_output), + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } else { + (data.to_string(), None, None) + } + } + Some("tasks_update") => { + let data_value = Value::Object(o.clone()); + let formatted_display = format_tasks_update(&data_value); + ( + formatted_display, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + Some("tasks_complete") => { + let data_value = Value::Object(o.clone()); + let formatted_summary = format_tasks_complete(&data_value); + ( + formatted_summary, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + _ => (data.to_string(), None, None), + }); + } + } + None +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index 3e1974323255..b1dbecc84b03 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -4,10 +4,12 @@ use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; -use crate::agents::sub_recipe_execution_tool::dashboard::{DisplayMode, TaskDashboard}; use crate::agents::sub_recipe_execution_tool::lib::{ Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ + DisplayMode, TaskExecutionTracker, +}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; @@ -19,12 +21,12 @@ pub async fn execute_single_task( notifier: mpsc::Sender, ) -> ExecutionResponse { let start_time = Instant::now(); - let dashboard = Arc::new(TaskDashboard::new( + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( vec![task.clone()], DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task(task, config.timeout_seconds, dashboard).await; + let result = process_task(task, config.timeout_seconds, task_execution_tracker).await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -40,9 +42,9 @@ pub async fn execute_tasks_in_parallel( config: Config, notifier: mpsc::Sender, ) -> ExecutionResponse { - let dashboard = Arc::new(TaskDashboard::new( + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( tasks.clone(), - DisplayMode::Dashboard, + DisplayMode::MultipleTasksOutput, notifier, )); let start_time = Instant::now(); @@ -52,7 +54,7 @@ pub async fn execute_tasks_in_parallel( return create_empty_response(); } - dashboard.refresh_display().await; + task_execution_tracker.refresh_display().await; let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); @@ -61,7 +63,7 @@ pub async fn execute_tasks_in_parallel( return create_error_response(e); } - let shared_state = create_shared_state(task_rx, result_tx, dashboard.clone()); + let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); // Simple static worker allocation - no dynamic scaling needed let worker_count = std::cmp::min(task_count, config.max_workers); @@ -71,7 +73,7 @@ pub async fn execute_tasks_in_parallel( worker_handles.push(handle); } - let results = collect_results(&mut result_rx, dashboard.clone(), task_count).await; + let results = collect_results(&mut result_rx, task_execution_tracker.clone(), task_count).await; for handle in worker_handles { if let Err(e) = handle.await { @@ -79,7 +81,7 @@ pub async fn execute_tasks_in_parallel( } } - dashboard.show_final_summary().await; + task_execution_tracker.send_tasks_complete().await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&results, execution_time); @@ -125,13 +127,13 @@ fn create_channels( fn create_shared_state( task_rx: mpsc::Receiver, result_tx: mpsc::Sender, - dashboard: Arc, + task_execution_tracker: Arc, ) -> Arc { Arc::new(SharedState { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - dashboard, + task_execution_tracker, }) } @@ -163,12 +165,12 @@ fn create_empty_response() -> ExecutionResponse { async fn collect_results( result_rx: &mut mpsc::Receiver, - dashboard: Arc, + task_execution_tracker: Arc, expected_count: usize, ) -> Vec { let mut results = Vec::new(); while let Some(result) = result_rx.recv().await { - dashboard + task_execution_tracker .complete_task(&result.task_id, result.clone()) .await; results.push(result); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index 6f131862fefd..03568dc53197 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,7 +1,7 @@ -mod dashboard; mod executor; pub mod lib; pub mod sub_recipe_execute_task_tool; +mod task_execution_tracker; mod tasks; mod types; pub mod utils; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs similarity index 50% rename from crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 572c82446073..5844f19640af 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -6,32 +6,24 @@ use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; -use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_display, get_task_name, -}; +use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; #[derive(Debug, Clone, PartialEq)] pub enum DisplayMode { - Dashboard, + MultipleTasksOutput, SingleTaskOutput, } const THROTTLE_INTERVAL_MS: u64 = 1000; -const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; -const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; -const CLEAR_TO_EOL: &str = "\x1b[K"; -const CLEAR_BELOW: &str = "\x1b[J"; -pub struct TaskDashboard { +pub struct TaskExecutionTracker { tasks: Arc>>, - last_display: Arc>, last_refresh: Arc>, - initial_display_shown: Arc>, notifier: mpsc::Sender, display_mode: DisplayMode, } -impl TaskDashboard { +impl TaskExecutionTracker { pub fn new( tasks: Vec, display_mode: DisplayMode, @@ -57,9 +49,7 @@ impl TaskDashboard { Self { tasks: Arc::new(RwLock::new(task_map)), - last_display: Arc::new(RwLock::new(String::new())), last_refresh: Arc::new(RwLock::new(Instant::now())), - initial_display_shown: Arc::new(RwLock::new(false)), notifier, display_mode, } @@ -89,6 +79,7 @@ impl TaskDashboard { pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { + // Send raw output data - let subscriber format it let _ = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { @@ -96,13 +87,15 @@ impl TaskDashboard { method: "notifications/message".to_string(), params: Some(json!({ "data": { - "type": "dashboard", - "display": format!("{}\n", line) + "type": "task_execution", + "subtype": "line_output", + "task_id": task_id, + "output": line } })), })); } - DisplayMode::Dashboard => { + DisplayMode::MultipleTasksOutput => { let mut tasks = self.tasks.write().await; if let Some(task_info) = tasks.get_mut(task_id) { task_info.current_output.push_str(line); @@ -129,72 +122,53 @@ impl TaskDashboard { } } - fn render_header(&self, display: &mut String, initial_shown: &mut bool) { - if !*initial_shown { - display.push_str(CLEAR_SCREEN); - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); - *initial_shown = true; - } else { - display.push_str(MOVE_TO_PROGRESS_LINE); - } - } - - fn render_progress_line(&self, display: &mut String, tasks: &HashMap) { - let (total, pending, running, completed, failed) = count_by_status(tasks); - display.push_str(&format!( - "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", - total, pending, running, completed, failed - )); - display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); - } - - fn render_task(&self, display: &mut String, task_info: &TaskInfo) { - let task_display = format_task_display(task_info, Instant::now()); - display.push_str(&task_display); - } + async fn send_tasks_update(&self) { + let tasks = self.tasks.read().await; + let task_list: Vec<_> = tasks.values().collect(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); - async fn update_display_if_changed(&self, display: String) { - let mut last_display = self.last_display.write().await; - if *last_display != display { - let _ = self - .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "dashboard", - "display": display.clone() - } - })), - })); - *last_display = display; - } + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "task_execution", + "subtype": "tasks_update", + "stats": { + "total": total, + "pending": pending, + "running": running, + "completed": completed, + "failed": failed + }, + "tasks": task_list.iter().map(|task_info| { + let now = Instant::now(); + json!({ + "id": task_info.task.id, + "status": task_info.status, + "duration_secs": task_info.start_time.map(|start| { + if let Some(end) = task_info.end_time { + end.duration_since(start).as_secs_f64() + } else { + now.duration_since(start).as_secs_f64() + } + }), + "current_output": task_info.current_output, + "task_type": task_info.task.task_type, + "error": task_info.error() + }) + }).collect::>() + } + })), + })); } pub async fn refresh_display(&self) { match self.display_mode { - DisplayMode::Dashboard => { - let tasks = self.tasks.read().await; - let mut display = String::new(); - - let mut initial_shown = self.initial_display_shown.write().await; - self.render_header(&mut display, &mut initial_shown); - drop(initial_shown); - - self.render_progress_line(&mut display, &tasks); - - let mut task_list: Vec<_> = tasks.values().collect(); - task_list.sort_by_key(|t| &t.task.id); - - for task_info in task_list { - self.render_task(&mut display, task_info); - } - - display.push_str(CLEAR_BELOW); - - self.update_display_if_changed(display).await; + DisplayMode::MultipleTasksOutput => { + self.send_tasks_update().await; } DisplayMode::SingleTaskOutput => { // No dashboard display needed for single task output mode @@ -203,39 +177,23 @@ impl TaskDashboard { } } - pub async fn show_final_summary(&self) { + pub async fn send_tasks_complete(&self) { let tasks = self.tasks.read().await; - - let mut summary = String::new(); - summary.push_str("Execution Complete!\n"); - summary.push_str("═══════════════════════\n"); - let (total, _, _, completed, failed) = count_by_status(&tasks); - summary.push_str(&format!("Total Tasks: {}\n", total)); - summary.push_str(&format!("✅ Completed: {}\n", completed)); - summary.push_str(&format!("❌ Failed: {}\n", failed)); - summary.push_str(&format!( - "📈 Success Rate: {:.1}%\n", - (completed as f64 / total as f64) * 100.0 - )); - - if failed > 0 { - summary.push_str("\n❌ Failed Tasks:\n"); - for task_info in tasks.values() { - if matches!(task_info.status, TaskStatus::Failed) { - let task_name = get_task_name(task_info); - summary.push_str(&format!(" • {}\n", task_name)); - if let Some(error) = task_info.error() { - summary.push_str(&format!(" Error: {}\n", error)); - } - } - } - } - - summary.push_str("\n📝 Generating summary...\n"); + // Send structured summary data only + let failed_tasks: Vec<_> = tasks + .values() + .filter(|task_info| matches!(task_info.status, TaskStatus::Failed)) + .map(|task_info| { + json!({ + "id": task_info.task.id, + "name": get_task_name(task_info), + "error": task_info.error() + }) + }) + .collect(); - // Send the final summary via notification let _ = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { @@ -243,8 +201,15 @@ impl TaskDashboard { method: "notifications/message".to_string(), params: Some(json!({ "data": { - "type": "dashboard", - "display": summary + "type": "task_execution", + "subtype": "tasks_complete", + "stats": { + "total": total, + "completed": completed, + "failed": failed, + "success_rate": (completed as f64 / total as f64) * 100.0 + }, + "failed_tasks": failed_tasks } })), })); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 969d885cedea..d59eac7cfe3b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -6,18 +6,23 @@ use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; pub async fn process_task( task: &Task, timeout_seconds: u64, - dashboard: Arc, + task_execution_tracker: Arc, ) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - match timeout(timeout_duration, get_task_result(task_clone, dashboard)).await { + match timeout( + timeout_duration, + get_task_result(task_clone, task_execution_tracker), + ) + .await + { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, @@ -39,10 +44,18 @@ pub async fn process_task( } } -async fn get_task_result(task: Task, dashboard: Arc) -> Result { +async fn get_task_result( + task: Task, + task_execution_tracker: Arc, +) -> Result { let (command, output_identifier) = build_command(&task)?; - let (stdout_output, stderr_output, success) = - run_command(command, &output_identifier, &task.id, dashboard).await?; + let (stdout_output, stderr_output, success) = run_command( + command, + &output_identifier, + &task.id, + task_execution_tracker, + ) + .await?; if success { process_output(stdout_output) @@ -93,7 +106,7 @@ async fn run_command( mut command: Command, output_identifier: &str, task_id: &str, - dashboard: Arc, + task_execution_tracker: Arc, ) -> Result<(String, String, bool), String> { let mut child = command .spawn() @@ -102,10 +115,20 @@ async fn run_command( let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let stdout_task = - spawn_output_reader(stdout, output_identifier, false, task_id, dashboard.clone()); - let stderr_task = - spawn_output_reader(stderr, output_identifier, true, task_id, dashboard.clone()); + let stdout_task = spawn_output_reader( + stdout, + output_identifier, + false, + task_id, + task_execution_tracker.clone(), + ); + let stderr_task = spawn_output_reader( + stderr, + output_identifier, + true, + task_id, + task_execution_tracker.clone(), + ); let status = child .wait() @@ -123,7 +146,7 @@ fn spawn_output_reader( output_identifier: &str, is_stderr: bool, task_id: &str, - dashboard: Arc, + task_execution_tracker: Arc, ) -> tokio::task::JoinHandle { let output_identifier = output_identifier.to_string(); let task_id = task_id.to_string(); @@ -136,7 +159,9 @@ fn spawn_output_reader( if !is_stderr { // Use dashboard's smart output handling based on display mode - dashboard.send_live_output(&task_id, &line).await; + task_execution_tracker + .send_live_output(&task_id, &line) + .await; } else { eprintln!("[stderr for {}] {}", output_identifier, line); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index f93c4f5c0c9a..62f37028efd5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { @@ -93,7 +93,7 @@ pub struct SharedState { pub task_receiver: Arc>>, pub result_sender: mpsc::Sender, pub active_workers: Arc, - pub dashboard: Arc, + pub task_execution_tracker: Arc, } impl SharedState { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index ea11a7e604cb..9e34aadda826 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -21,8 +21,9 @@ pub fn spawn_worker( async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { while let Some(task) = receive_task(&state).await { - state.dashboard.start_task(&task.id).await; - let result = process_task(&task, timeout_seconds, state.dashboard.clone()).await; + state.task_execution_tracker.start_task(&task.id).await; + let result = + process_task(&task, timeout_seconds, state.task_execution_tracker.clone()).await; if let Err(e) = state.result_sender.send(result).await { eprintln!("Worker failed to send result: {}", e); From d8fdedd9c0e02d12def90963c1f6ad1da2b22eec Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 07:00:20 +1000 Subject: [PATCH 12/43] better console ui and run sub recipe in no-session mode --- .../src/session/task_execution_display.rs | 160 +++++++++++++----- .../task_execution_tracker.rs | 25 +++ .../agents/sub_recipe_execution_tool/tasks.rs | 2 +- 3 files changed, 142 insertions(+), 45 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index ba57c00dedd6..38542ba2c29c 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -9,7 +9,6 @@ pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; pub fn format_tasks_update(data: &Value) -> String { let mut display = String::new(); - // Determine if this is initial display or update static mut INITIAL_SHOWN: bool = false; unsafe { if !INITIAL_SHOWN { @@ -37,58 +36,131 @@ pub fn format_tasks_update(data: &Value) -> String { } if let Some(tasks) = data.get("tasks").and_then(|t| t.as_array()) { - for task in tasks { - let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); - let status = task - .get("status") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let task_type = task - .get("task_type") - .and_then(|v| v.as_str()) - .unwrap_or("task"); - - let status_icon = match status { - "Pending" => "⏳", - "Running" => "🏃", - "Completed" => "✅", - "Failed" => "❌", - _ => "◯", - }; - - display.push_str(&format!( - "{} {} ({}): {}\n", - status_icon, id, task_type, status - )); + let mut sorted_tasks: Vec<_> = tasks.iter().collect(); + sorted_tasks.sort_by_key(|task| task.get("id").and_then(|v| v.as_str()).unwrap_or("")); - if status == "Running" { - if let Some(output) = task.get("current_output").and_then(|v| v.as_str()) { - if !output.trim().is_empty() { - let lines: Vec<&str> = output.lines().collect(); - if lines.len() > 3 { - display.push_str(" ...\n"); - for line in lines.iter().rev().take(3).rev() { - display.push_str(&format!(" {}\n", line)); - } - } else { - for line in lines { - display.push_str(&format!(" {}\n", line)); - } - } - } + for task in sorted_tasks { + display.push_str(&format_task_from_json(task)); + } + } + + display.push_str(CLEAR_BELOW); + display +} + +fn format_task_from_json(task: &Value) -> String { + let mut task_display = String::new(); + + let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let status = task + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let task_type = task + .get("task_type") + .and_then(|v| v.as_str()) + .unwrap_or("task"); + let task_name = task.get("task_name").and_then(|v| v.as_str()).unwrap_or(id); + let task_metadata = task + .get("task_metadata") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let status_icon = match status { + "Pending" => "⏳", + "Running" => "🏃", + "Completed" => "✅", + "Failed" => "❌", + _ => "◯", + }; + + task_display.push_str(&format!( + "{} {} ({}){}\n", + status_icon, task_name, task_type, CLEAR_TO_EOL + )); + + if !task_metadata.is_empty() { + task_display.push_str(&format!( + " 📋 Parameters: {}{}\n", + task_metadata, CLEAR_TO_EOL + )); + } + + if let Some(duration_secs) = task.get("duration_secs").and_then(|v| v.as_f64()) { + task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); + } + + if status == "Running" { + if let Some(current_output) = task.get("current_output").and_then(|v| v.as_str()) { + if !current_output.trim().is_empty() { + let processed_output = process_output_for_display(current_output); + if !processed_output.is_empty() { + task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); } } + } + } - if status == "Failed" { - if let Some(error) = task.get("error").and_then(|v| v.as_str()) { - display.push_str(&format!(" Error: {}\n", error)); + if status == "Failed" { + if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + let error_preview = truncate_with_ellipsis(error, 80); + task_display.push_str(&format!( + " ⚠️ {}{}\n", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + )); + } + } + + task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); + task_display +} + +fn process_output_for_display(output: &str) -> String { + const MAX_OUTPUT_LINES: usize = 2; + const OUTPUT_PREVIEW_LENGTH: usize = 100; + + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > MAX_OUTPUT_LINES { + &lines[lines.len() - MAX_OUTPUT_LINES..] + } else { + &lines + }; + + let clean_output = recent_lines.join(" | "); + let stripped = strip_ansi_codes(&clean_output); + truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH) +} + +fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { + if text.len() > max_len { + format!("{}...", &text[..max_len.saturating_sub(3)]) + } else { + text.to_string() + } +} + +fn strip_ansi_codes(text: &str) -> String { + let mut result = String::new(); + let mut chars = text.chars(); + + while let Some(ch) = chars.next() { + if ch == '\x1b' { + if chars.next() == Some('[') { + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } } } + } else { + result.push(ch); } } - display.push_str(CLEAR_BELOW); - display + result } pub fn format_tasks_complete(data: &Value) -> String { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 5844f19640af..59325e18bf75 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -7,6 +7,7 @@ use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; +use serde_json::Value; #[derive(Debug, Clone, PartialEq)] pub enum DisplayMode { @@ -16,6 +17,28 @@ pub enum DisplayMode { const THROTTLE_INTERVAL_MS: u64 = 1000; +fn format_task_metadata(task_info: &TaskInfo) -> String { + if let Some(params) = task_info.task.get_command_parameters() { + if params.is_empty() { + return String::new(); + } + + params + .iter() + .map(|(key, value)| { + let value_str = match value { + Value::String(s) => s.clone(), + _ => value.to_string(), + }; + format!("{}={}", key, value_str) + }) + .collect::>() + .join(",") + } else { + String::new() + } +} + pub struct TaskExecutionTracker { tasks: Arc>>, last_refresh: Arc>, @@ -157,6 +180,8 @@ impl TaskExecutionTracker { }), "current_output": task_info.current_output, "task_type": task_info.task.task_type, + "task_name": get_task_name(task_info), + "task_metadata": format_task_metadata(task_info), "error": task_info.error() }) }).collect::>() diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index d59eac7cfe3b..b975184da343 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -79,7 +79,7 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--recipe").arg(path); + cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); for (key, value) in command_parameters { let key_str = key.to_string(); From f1c38f93355a6f2a3a12f8963fa1021e51c6aff8 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 09:45:26 +1000 Subject: [PATCH 13/43] added timeout in task execution --- .../src/agents/recipe_tools/sub_recipe_tools.rs | 4 ++++ .../agents/sub_recipe_execution_tool/executor.rs | 5 ++--- .../src/agents/sub_recipe_execution_tool/lib.rs | 2 +- .../sub_recipe_execute_task_tool.rs | 11 +++++------ .../src/agents/sub_recipe_execution_tool/tasks.rs | 4 ++-- .../src/agents/sub_recipe_execution_tool/types.rs | 8 ++------ .../src/agents/sub_recipe_execution_tool/workers.rs | 13 ++++--------- crates/goose/src/recipe/mod.rs | 1 + 8 files changed, 21 insertions(+), 27 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 8cd1a8599037..116e4a0a5193 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -57,6 +57,10 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re Task { id: uuid::Uuid::new_v4().to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: sub_recipe + .executions + .as_ref() + .and_then(|e| e.timeout_in_seconds), payload, } }) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index b1dbecc84b03..4ec7443a40b4 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -17,7 +17,6 @@ const EXECUTION_STATUS_COMPLETED: &str = "completed"; pub async fn execute_single_task( task: &Task, - config: Config, notifier: mpsc::Sender, ) -> ExecutionResponse { let start_time = Instant::now(); @@ -26,7 +25,7 @@ pub async fn execute_single_task( DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task(task, config.timeout_seconds, task_execution_tracker).await; + let result = process_task(task, task_execution_tracker).await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -69,7 +68,7 @@ pub async fn execute_tasks_in_parallel( let worker_count = std::cmp::min(task_count, config.max_workers); let mut worker_handles = Vec::new(); for i in 0..worker_count { - let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); + let handle = spawn_worker(shared_state.clone(), i); worker_handles.push(handle); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 5973025a1786..6a18b24907a8 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -29,7 +29,7 @@ pub async fn execute_tasks( match execution_mode { "sequential" => { if task_count == 1 { - let response = execute_single_task(&tasks[0], config, notifier).await; + let response = execute_single_task(&tasks[0], notifier).await; serde_json::to_value(response) .map_err(|e| format!("Failed to serialize response: {}", e)) } else { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 870350a55da1..40fe117a05ba 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -62,6 +62,10 @@ Pre-created Task Based: "default": "sub_recipe", "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" }, + "timeout_in_seconds": { + "type": "number", + "description": "timeout in seconds for the task." + }, "payload": { "type": "object", "properties": { @@ -97,9 +101,6 @@ Pre-created Task Based: "config": { "type": "object", "properties": { - "timeout_seconds": { - "type": "number" - }, "max_workers": { "type": "number" }, @@ -126,11 +127,9 @@ pub async fn run_tasks(execute_data: Value) -> ToolCallResult { let result_future = async move { let execute_data_clone = execute_data.clone(); - let default_execution_mode_value = Value::String("sequential".to_string()); let execution_mode = execute_data_clone .get("execution_mode") - .unwrap_or(&default_execution_mode_value) - .as_str() + .and_then(|v| v.as_str()) .unwrap_or("sequential"); match execute_tasks(execute_data, execution_mode, notification_tx).await { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index b975184da343..f5762f973855 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -11,11 +11,11 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStat pub async fn process_task( task: &Task, - timeout_seconds: u64, task_execution_tracker: Arc, ) -> TaskResult { + let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(300); let task_clone = task.clone(); - let timeout_duration = Duration::from_secs(timeout_seconds); + let timeout_duration = Duration::from_secs(timeout_in_seconds); match timeout( timeout_duration, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 62f37028efd5..4184eb65ec36 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -10,6 +10,7 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecut pub struct Task { pub id: String, pub task_type: String, + pub timeout_in_seconds: Option, pub payload: Value, } @@ -110,8 +111,6 @@ impl SharedState { pub struct Config { #[serde(default = "default_max_workers")] pub max_workers: usize, - #[serde(default = "default_timeout")] - pub timeout_seconds: u64, #[serde(default = "default_initial_workers")] pub initial_workers: usize, } @@ -120,7 +119,6 @@ impl Default for Config { fn default() -> Self { Self { max_workers: default_max_workers(), - timeout_seconds: default_timeout(), initial_workers: default_initial_workers(), } } @@ -129,9 +127,7 @@ impl Default for Config { fn default_max_workers() -> usize { 10 } -fn default_timeout() -> u64 { - 300 -} + fn default_initial_workers() -> usize { 2 } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 9e34aadda826..f891595456b5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -7,23 +7,18 @@ async fn receive_task(state: &SharedState) -> Option { receiver.recv().await } -pub fn spawn_worker( - state: Arc, - worker_id: usize, - timeout_seconds: u64, -) -> tokio::task::JoinHandle<()> { +pub fn spawn_worker(state: Arc, worker_id: usize) -> tokio::task::JoinHandle<()> { state.increment_active_workers(); tokio::spawn(async move { - worker_loop(state, worker_id, timeout_seconds).await; + worker_loop(state, worker_id).await; }) } -async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { +async fn worker_loop(state: Arc, _worker_id: usize) { while let Some(task) = receive_task(&state).await { state.task_execution_tracker.start_task(&task.id).await; - let result = - process_task(&task, timeout_seconds, state.task_execution_tracker.clone()).await; + let result = process_task(&task, state.task_execution_tracker.clone()).await; if let Err(e) = state.result_sender.send(result).await { eprintln!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 75c08a42fa16..3e4727c4092e 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -144,6 +144,7 @@ pub struct SubRecipe { pub struct Execution { #[serde(default)] pub parallel: bool, + pub timeout_in_seconds: Option, #[serde(skip_serializing_if = "Option::is_none")] pub runs: Option>, } From 54d5224e70336feb24f545a43836874c9eb9d8aa Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 10:52:44 +1000 Subject: [PATCH 14/43] cleaned up some utils and added the error tool response --- .../goose-cli/src/recipes/extract_from_cli.rs | 1 + .../agents/recipe_tools/param_utils/tests.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 5 +- .../recipe_tools/sub_recipe_tools/tests.rs | 1 + .../agents/sub_recipe_execution_tool/lib.rs | 38 +- .../sub_recipe_execute_task_tool.rs | 1 - .../sub_recipe_execution_tool/utils/mod.rs | 167 ------- .../sub_recipe_execution_tool/utils/tests.rs | 453 +----------------- crates/goose/src/recipe/mod.rs | 2 +- 9 files changed, 47 insertions(+), 622 deletions(-) diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 84d578c0cae4..8ba8658c8cb6 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -31,6 +31,7 @@ pub fn extract_recipe_info_from_cli( let additional_sub_recipe = SubRecipe { path: recipe_file_path.to_string_lossy().to_string(), name, + timeout_in_seconds: None, values: None, executions: None, }; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index 81e5c412a8c7..e7b3f6878eb9 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -12,6 +12,7 @@ mod tests { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), + timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), executions: None, }; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 116e4a0a5193..88b100c5e67f 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -57,10 +57,7 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re Task { id: uuid::Uuid::new_v4().to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: sub_recipe - .executions - .as_ref() - .and_then(|e| e.timeout_in_seconds), + timeout_in_seconds: sub_recipe.timeout_in_seconds, payload, } }) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index a4956f65edda..ca66e97819bb 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -11,6 +11,7 @@ mod tests { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), + timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), executions: None, }; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 6a18b24907a8..f51ceea71135 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -30,17 +30,45 @@ pub async fn execute_tasks( "sequential" => { if task_count == 1 { let response = execute_single_task(&tasks[0], notifier).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) + handle_response(response) } else { Err("Sequential execution mode requires exactly one task".to_string()) } } "parallel" => { - let response = execute_tasks_in_parallel(tasks, config, notifier).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) + let response: ExecutionResponse = execute_tasks_in_parallel(tasks, config, notifier).await; + handle_response(response) } _ => Err("Invalid execution mode".to_string()), } } + +fn handle_response(response: ExecutionResponse) -> Result { + if response.stats.failed > 0 { + let failed_tasks: Vec = response.results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .map(|r| { + let error_msg = r.error.as_ref().map(|s| s.as_str()).unwrap_or("Unknown error"); + format!("Task '{}' ({}): {}", r.task_id, get_task_description(r), error_msg) + }) + .collect(); + + let error_summary = format!( + "{}/{} tasks failed:\n{}", + response.stats.failed, + response.stats.total_tasks, + failed_tasks.join("\n") + ); + + return Err(error_summary); + } + serde_json::to_value(response) + .map_err(|e| format!("Failed to serialize response: {}", e)) +} + +fn get_task_description(result: &TaskResult) -> String { + // We'd need to reconstruct task info from the result or pass it through + // For now, just use the task_id as placeholder + format!("ID: {}", result.task_id) +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 40fe117a05ba..f1ca709cb318 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -13,7 +13,6 @@ pub fn create_sub_recipe_execute_task_tool() -> Tool { Tool::new( SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, "Only use this tool when you execute sub recipe task. - EXECUTION STRATEGY DECISION: 1. PRE-CREATED TASKS: If tasks were created by subrecipe__create_task_* tools, check the execution_mode in the response: - If execution_mode is 'parallel', use parallel execution diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index 7b684126c383..bf806a50d182 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,15 +1,8 @@ use serde_json::{Map, Value}; use std::collections::HashMap; -use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; -// Constants for display formatting -const MAX_OUTPUT_LINES: usize = 2; -const OUTPUT_PREVIEW_LENGTH: usize = 100; -const ERROR_PREVIEW_LENGTH: usize = 80; -const CLEAR_TO_EOL: &str = "\x1b[K"; - pub fn get_task_name(task_info: &TaskInfo) -> &str { task_info .task @@ -21,14 +14,6 @@ pub fn get_command_parameters(task_info: &TaskInfo) -> Option<&Map String { - if text.len() > max_len { - format!("{}...", &text[..max_len.saturating_sub(3)]) - } else { - text.to_string() - } -} - pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); let (pending, running, completed, failed) = tasks.values().fold( @@ -43,157 +28,5 @@ pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usiz (total, pending, running, completed, failed) } -pub fn strip_ansi_codes(text: &str) -> String { - let mut result = String::new(); - let mut chars = text.chars(); - - while let Some(ch) = chars.next() { - if ch == '\x1b' { - if chars.next() == Some('[') { - loop { - match chars.next() { - Some(c) if c.is_ascii_alphabetic() => break, - Some(_) => continue, - None => break, - } - } - } - } else { - result.push(ch); - } - } - - result -} - -pub fn get_status_icon(status: &TaskStatus) -> &'static str { - match status { - TaskStatus::Pending => "⏳", - TaskStatus::Running => "🏃", - TaskStatus::Completed => "✅", - TaskStatus::Failed => "❌", - } -} - -pub fn process_output_lines(output: &str) -> String { - let lines: Vec<&str> = output.lines().collect(); - let recent_lines = if lines.len() > MAX_OUTPUT_LINES { - &lines[lines.len() - MAX_OUTPUT_LINES..] - } else { - &lines - }; - - let clean_output = recent_lines.join("\n"); - strip_ansi_codes(&clean_output) -} - -/// Format task timing information -pub fn format_task_timing(task_info: &TaskInfo, current_time: Instant) -> Option { - task_info.start_time.map(|start_time| { - let duration = if let Some(end_time) = task_info.end_time { - end_time.duration_since(start_time) - } else { - current_time.duration_since(start_time) - }; - format!( - " ⏱️ {:.1}s{} -", - duration.as_secs_f64(), - CLEAR_TO_EOL - ) - }) -} - -/// Format task output preview -pub fn format_task_output(task_info: &TaskInfo) -> Option { - if matches!(task_info.status, TaskStatus::Running) && !task_info.current_output.is_empty() { - let output_preview = - truncate_with_ellipsis(&task_info.current_output, OUTPUT_PREVIEW_LENGTH); - Some(format!( - " 💬 {}{} -", - output_preview.replace('\n', " | "), - CLEAR_TO_EOL - )) - } else { - None - } -} - -/// Format task error information -pub fn format_task_error(task_info: &TaskInfo) -> Option { - task_info.error().map(|error| { - let error_preview = truncate_with_ellipsis(error, ERROR_PREVIEW_LENGTH); - format!( - " ⚠️ {}{} -", - error_preview.replace('\n', " "), - CLEAR_TO_EOL - ) - }) -} - -pub fn format_command_parameters(task_info: &TaskInfo) -> Option { - get_command_parameters(task_info).map(|params| { - if params.is_empty() { - return format!(" 📋 Parameters: (none){}\n", CLEAR_TO_EOL); - } - - let params_str = params - .iter() - .map(|(key, value)| { - let value_str = match value { - Value::String(s) => s.clone(), - _ => value.to_string(), - }; - format!("{}={}", key, value_str) - }) - .collect::>() - .join(", "); - - let params_preview = truncate_with_ellipsis(¶ms_str, OUTPUT_PREVIEW_LENGTH); - format!(" 📋 Parameters: {}{}\n", params_preview, CLEAR_TO_EOL) - }) -} - -pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> String { - let mut display = String::new(); - - let status_icon = get_status_icon(&task_info.status); - let task_name = get_task_name(task_info); - - // Task status line - display.push_str(&format!( - "{} {} ({}){} -", - status_icon, task_name, task_info.task.task_type, CLEAR_TO_EOL - )); - - if let Some(params) = format_command_parameters(task_info) { - display.push_str(¶ms); - } - - if let Some(timing) = format_task_timing(task_info, current_time) { - display.push_str(&timing); - } - - if let Some(output) = format_task_output(task_info) { - display.push_str(&output); - } - - // Task error (if failed) - if let Some(error) = format_task_error(task_info) { - display.push_str(&error); - } - - display.push_str(&format!( - "{} -", - CLEAR_TO_EOL - )); - - display -} - #[cfg(test)] mod tests; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index 88d823092c34..f921f131781f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,101 +1,20 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_display, format_task_error, format_task_output, - format_task_timing, get_status_icon, get_task_name, process_output_lines, strip_ansi_codes, - truncate_with_ellipsis, + count_by_status, + format_task_timing, get_status_icon, get_task_name, }; use serde_json::json; use std::collections::HashMap; -use tokio::time::Instant; -mod truncate_with_ellipsis { +mod test_get_task_name { use super::*; #[test] - fn returns_original_when_under_limit() { - assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); - assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); - } - - #[test] - fn truncates_when_over_limit() { - assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); - assert_eq!( - truncate_with_ellipsis("very long text here", 10), - "very lo..." - ); - } - - #[test] - fn handles_empty_string() { - assert_eq!(truncate_with_ellipsis("", 5), ""); - } - - #[test] - fn handles_exact_limit() { - assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); - } - - #[test] - fn handles_very_short_limit() { - assert_eq!(truncate_with_ellipsis("hello", 3), "..."); - assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is - } -} - -mod strip_ansi_codes { - use super::*; - - #[test] - fn preserves_plain_text() { - assert_eq!(strip_ansi_codes("hello world"), "hello world"); - assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); - } - - #[test] - fn removes_color_codes() { - assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); - assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); - } - - #[test] - fn removes_complex_formatting() { - assert_eq!( - strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), - "bold green" - ); - assert_eq!( - strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), - "underline red" - ); - } - - #[test] - fn handles_multiple_sequences() { - let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; - assert_eq!(strip_ansi_codes(input), "red normal green"); - } - - #[test] - fn handles_empty_string() { - assert_eq!(strip_ansi_codes(""), ""); - } - - #[test] - fn handles_malformed_sequences() { - // Incomplete escape sequence - our function strips the \x1b char - assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); - } -} - -mod get_task_name { - use super::*; - - #[test] - fn extracts_sub_recipe_name() { + fn test_extracts_sub_recipe_name() { let sub_recipe_task = Task { id: "task_1".to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: None, payload: json!({ "sub_recipe": { "name": "my_recipe", @@ -120,6 +39,7 @@ mod get_task_name { fn falls_back_to_task_id_for_text_instruction() { let text_task = Task { id: "task_2".to_string(), + timeout_in_seconds: None, task_type: "text_instruction".to_string(), payload: json!({"text_instruction": "do something"}), }; @@ -141,6 +61,7 @@ mod get_task_name { let malformed_task = Task { id: "task_3".to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: None, payload: json!({ "sub_recipe": { "recipe_path": "/path/to/recipe" @@ -166,6 +87,7 @@ mod get_task_name { let malformed_task = Task { id: "task_4".to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: None, payload: json!({}), // missing "sub_recipe" field }; @@ -190,6 +112,7 @@ mod count_by_status { task: Task { id: id.to_string(), task_type: "test".to_string(), + timeout_in_seconds: None, payload: json!({}), }, status, @@ -260,361 +183,3 @@ mod count_by_status { ); } } - -mod get_status_icon { - use super::*; - - #[test] - fn returns_correct_icon_for_pending() { - assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); - } - - #[test] - fn returns_correct_icon_for_running() { - assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); - } - - #[test] - fn returns_correct_icon_for_completed() { - assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); - } - - #[test] - fn returns_correct_icon_for_failed() { - assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); - } -} - -mod process_output_lines { - use super::*; - - #[test] - fn preserves_short_output() { - let output = "line 1\nline 2"; - assert_eq!(process_output_lines(output), "line 1\nline 2"); - } - - #[test] - fn keeps_only_recent_lines_when_too_many() { - let output = "line 1\nline 2\nline 3\nline 4\nline 5"; - let result = process_output_lines(output); - assert_eq!(result, "line 4\nline 5"); - } - - #[test] - fn strips_ansi_codes_from_output() { - let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "red line 1\ngreen line 2"); - } - - #[test] - fn handles_empty_output() { - assert_eq!(process_output_lines(""), ""); - } - - #[test] - fn handles_single_line() { - assert_eq!(process_output_lines("single line"), "single line"); - } - - #[test] - fn combines_ansi_stripping_and_line_limiting() { - let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "line 3\nline 4"); - } -} - -mod format_task_timing { - use super::*; - use std::time::Duration; - - fn create_test_task_info_with_timing(start: Option, end: Option) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Running, - start_time: start, - end_time: end, - result: None, - current_output: String::new(), - } - } - - #[test] - fn returns_none_when_no_start_time() { - let task_info = create_test_task_info_with_timing(None, None); - let current_time = Instant::now(); - assert!(format_task_timing(&task_info, current_time).is_none()); - } - - #[test] - fn formats_running_task_duration() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_millis(1500); - let task_info = create_test_task_info_with_timing(Some(start_time), None); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("1.5s")); - assert!(result.contains("⏱️")); - } - - #[test] - fn formats_completed_task_duration() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(2500); - let current_time = Instant::now(); // This shouldn't matter for completed tasks - let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("2.5s")); - assert!(result.contains("⏱️")); - } -} - -mod format_task_output { - use super::*; - - fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: output.to_string(), - } - } - - #[test] - fn returns_none_for_non_running_tasks() { - let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); - assert!(format_task_output(&task_info).is_none()); - - let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); - assert!(format_task_output(&task_info).is_none()); - - let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); - assert!(format_task_output(&task_info).is_none()); - } - - #[test] - fn returns_none_for_running_task_with_empty_output() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); - assert!(format_task_output(&task_info).is_none()); - } - - #[test] - fn formats_running_task_with_output() { - let task_info = - create_test_task_info_with_output(TaskStatus::Running, "Building project..."); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("💬")); - assert!(result.contains("Building project...")); - } - - #[test] - fn replaces_newlines_with_pipes() { - let task_info = - create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("line 1 | line 2 | line 3")); - } - - #[test] - fn truncates_long_output() { - let long_output = "a".repeat(150); - let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_output.len() + 20); // Account for formatting - } -} - -mod format_task_error { - use super::*; - use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; - - fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { - let result = error_msg.map(|msg| TaskResult { - task_id: "test_task".to_string(), - status: TaskStatus::Failed, - data: None, - error: Some(msg.to_string()), - }); - - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Failed, - start_time: None, - end_time: None, - result, - current_output: String::new(), - } - } - - #[test] - fn returns_none_when_no_error() { - let task_info = create_test_task_info_with_error(None); - assert!(format_task_error(&task_info).is_none()); - } - - #[test] - fn formats_error_message() { - let task_info = create_test_task_info_with_error(Some("File not found")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("⚠️")); - assert!(result.contains("File not found")); - } - - #[test] - fn replaces_newlines_in_error() { - let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("Error on line 1 Error on line 2")); - } - - #[test] - fn truncates_long_error() { - let long_error = "error ".repeat(30); - let task_info = create_test_task_info_with_error(Some(&long_error)); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_error.len() + 20); // Account for formatting - } -} - -mod format_task_display { - use super::*; - use std::time::Duration; - - fn create_comprehensive_task_info( - task_name: &str, - status: TaskStatus, - start_time: Option, - end_time: Option, - current_output: &str, - error: Option<&str>, - ) -> TaskInfo { - let result = error.map( - |msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { - task_id: task_name.to_string(), - status: status.clone(), - data: None, - error: Some(msg.to_string()), - }, - ); - - TaskInfo { - task: Task { - id: task_name.to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time, - end_time, - result, - current_output: current_output.to_string(), - } - } - - #[test] - fn formats_pending_task() { - let task_info = create_comprehensive_task_info( - "pending_task", - TaskStatus::Pending, - None, - None, - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("⏳")); - assert!(result.contains("pending_task")); - assert!(result.contains("(test)")); - } - - #[test] - fn formats_running_task_with_output() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_secs(2); - let task_info = create_comprehensive_task_info( - "running_task", - TaskStatus::Running, - Some(start_time), - None, - "Compiling...", - None, - ); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("🏃")); - assert!(result.contains("running_task")); - assert!(result.contains("2.0s")); - assert!(result.contains("💬")); - assert!(result.contains("Compiling...")); - } - - #[test] - fn formats_failed_task_with_error() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(1500); - let task_info = create_comprehensive_task_info( - "failed_task", - TaskStatus::Failed, - Some(start_time), - Some(end_time), - "", - Some("Compilation failed"), - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("❌")); - assert!(result.contains("failed_task")); - assert!(result.contains("1.5s")); - assert!(result.contains("⚠️")); - assert!(result.contains("Compilation failed")); - } - - #[test] - fn formats_completed_task() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_secs(3); - let task_info = create_comprehensive_task_info( - "completed_task", - TaskStatus::Completed, - Some(start_time), - Some(end_time), - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("✅")); - assert!(result.contains("completed_task")); - assert!(result.contains("3.0s")); - } -} diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 3e4727c4092e..fe4e514d02a8 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -135,6 +135,7 @@ pub struct Response { pub struct SubRecipe { pub name: String, pub path: String, + pub timeout_in_seconds: Option, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -144,7 +145,6 @@ pub struct SubRecipe { pub struct Execution { #[serde(default)] pub parallel: bool, - pub timeout_in_seconds: Option, #[serde(skip_serializing_if = "Option::is_none")] pub runs: Option>, } From 47e769ec0f8121703cb704303b7e794a6e174dd5 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 13:25:19 +1000 Subject: [PATCH 15/43] return task output if task times out --- .../agents/sub_recipe_execution_tool/lib.rs | 31 ++++++++++++++----- .../task_execution_tracker.rs | 7 +++++ .../agents/sub_recipe_execution_tool/tasks.rs | 23 +++++++++----- .../sub_recipe_execution_tool/utils/tests.rs | 3 +- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index f51ceea71135..d3b1f51bb517 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -36,7 +36,8 @@ pub async fn execute_tasks( } } "parallel" => { - let response: ExecutionResponse = execute_tasks_in_parallel(tasks, config, notifier).await; + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, config, notifier).await; handle_response(response) } _ => Err("Invalid execution mode".to_string()), @@ -45,26 +46,40 @@ pub async fn execute_tasks( fn handle_response(response: ExecutionResponse) -> Result { if response.stats.failed > 0 { - let failed_tasks: Vec = response.results + let failed_tasks: Vec = response + .results .iter() .filter(|r| matches!(r.status, TaskStatus::Failed)) .map(|r| { - let error_msg = r.error.as_ref().map(|s| s.as_str()).unwrap_or("Unknown error"); - format!("Task '{}' ({}): {}", r.task_id, get_task_description(r), error_msg) + let error_msg = r.error.as_deref().unwrap_or("Unknown error"); + let partial_output = r + .data + .as_ref() + .and_then(|d| d.get("partial_output")) + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("No output captured"); + + format!( + "Task '{}' ({}): {}, \noutput: {}", + r.task_id, + get_task_description(r), + error_msg, + partial_output + ) }) .collect(); - + let error_summary = format!( "{}/{} tasks failed:\n{}", response.stats.failed, response.stats.total_tasks, failed_tasks.join("\n") ); - + return Err(error_summary); } - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) + serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e)) } fn get_task_description(result: &TaskResult) -> String { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 59325e18bf75..9d18f89ba7fa 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -99,6 +99,13 @@ impl TaskExecutionTracker { self.refresh_display().await; } + pub async fn get_current_output(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks + .get(task_id) + .map(|task_info| task_info.current_output.clone()) + } + pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index f5762f973855..428f7bc81f8a 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -17,6 +17,7 @@ pub async fn process_task( let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_in_seconds); + let task_execution_tracker_clone = task_execution_tracker.clone(); match timeout( timeout_duration, get_task_result(task_clone, task_execution_tracker), @@ -35,12 +36,21 @@ pub async fn process_task( data: None, error: Some(error), }, - Err(_) => TaskResult { - task_id: task.id.clone(), - status: TaskStatus::Failed, - data: None, - error: Some("Task timeout".to_string()), - }, + Err(_) => { + let current_output = task_execution_tracker_clone + .get_current_output(&task.id) + .await + .unwrap_or_default(); + + TaskResult { + task_id: task.id.clone(), + status: TaskStatus::Failed, + data: Some(serde_json::json!({ + "partial_output": current_output + })), + error: Some(format!("Task timed out after {}s", timeout_in_seconds)), + } + } } } @@ -158,7 +168,6 @@ fn spawn_output_reader( buffer.push('\n'); if !is_stderr { - // Use dashboard's smart output handling based on display mode task_execution_tracker .send_live_output(&task_id, &line) .await; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index f921f131781f..e05e00d3ee13 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,7 +1,6 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, - format_task_timing, get_status_icon, get_task_name, + count_by_status, format_task_timing, get_status_icon, get_task_name, }; use serde_json::json; use std::collections::HashMap; From 85d5286e3f61a25de503f27db8bd402c3fd44ab8 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 14:05:02 +1000 Subject: [PATCH 16/43] changed console output a bit --- .../src/session/task_execution_display.rs | 30 ++++++++++++++++++- .../task_execution_tracker.rs | 26 +++++++++++++--- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index 38542ba2c29c..38a168241ac9 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -101,6 +101,15 @@ fn format_task_from_json(task: &Value) -> String { } } + if status == "Completed" { + if let Some(result_data) = task.get("result_data") { + let result_preview = format_result_data_for_display(result_data); + if !result_preview.is_empty() { + task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); + } + } + } + if status == "Failed" { if let Some(error) = task.get("error").and_then(|v| v.as_str()) { let error_preview = truncate_with_ellipsis(error, 80); @@ -116,6 +125,25 @@ fn format_task_from_json(task: &Value) -> String { task_display } +fn format_result_data_for_display(result_data: &Value) -> String { + match result_data { + Value::String(s) => strip_ansi_codes(s), + Value::Object(obj) => { + // Handle specific result formats + if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) { + format!("Partial output: {}", partial_output) + } else { + // Generic object display + serde_json::to_string_pretty(obj).unwrap_or_default() + } + } + Value::Array(arr) => serde_json::to_string_pretty(arr).unwrap_or_default(), + Value::Bool(b) => b.to_string(), + Value::Number(n) => n.to_string(), + Value::Null => "null".to_string(), + } +} + fn process_output_for_display(output: &str) -> String { const MAX_OUTPUT_LINES: usize = 2; const OUTPUT_PREVIEW_LENGTH: usize = 100; @@ -127,7 +155,7 @@ fn process_output_for_display(output: &str) -> String { &lines }; - let clean_output = recent_lines.join(" | "); + let clean_output = recent_lines.join(" ... "); let stripped = strip_ansi_codes(&clean_output); truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 9d18f89ba7fa..e3769dd84760 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -15,7 +15,7 @@ pub enum DisplayMode { SingleTaskOutput, } -const THROTTLE_INTERVAL_MS: u64 = 1000; +const THROTTLE_INTERVAL_MS: u64 = 250; fn format_task_metadata(task_info: &TaskInfo) -> String { if let Some(params) = task_info.task.get_command_parameters() { @@ -85,7 +85,7 @@ impl TaskExecutionTracker { task_info.start_time = Some(Instant::now()); } drop(tasks); - self.refresh_display().await; + self.force_refresh_display().await; } pub async fn complete_task(&self, task_id: &str, result: TaskResult) { @@ -96,7 +96,7 @@ impl TaskExecutionTracker { task_info.result = Some(result); } drop(tasks); - self.refresh_display().await; + self.force_refresh_display().await; } pub async fn get_current_output(&self, task_id: &str) -> Option { @@ -189,7 +189,8 @@ impl TaskExecutionTracker { "task_type": task_info.task.task_type, "task_name": get_task_name(task_info), "task_metadata": format_task_metadata(task_info), - "error": task_info.error() + "error": task_info.error(), + "result_data": task_info.data() }) }).collect::>() } @@ -209,6 +210,23 @@ impl TaskExecutionTracker { } } + // Force refresh without throttling - used for important status changes + async fn force_refresh_display(&self) { + match self.display_mode { + DisplayMode::MultipleTasksOutput => { + // Reset throttle timer to allow immediate update + let mut last_refresh = self.last_refresh.write().await; + *last_refresh = Instant::now() - Duration::from_millis(THROTTLE_INTERVAL_MS + 1); + drop(last_refresh); + + self.send_tasks_update().await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + } + } + } + pub async fn send_tasks_complete(&self) { let tasks = self.tasks.read().await; let (total, _, _, completed, failed) = count_by_status(&tasks); From 2ccfb6809d53194335f0e297875d9dad950f2380 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 14:12:46 +1000 Subject: [PATCH 17/43] test fmt --- crates/goose/src/agents/recipe_tools/param_utils/tests.rs | 1 - .../goose/src/agents/sub_recipe_execution_tool/utils/tests.rs | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index e7b3f6878eb9..fcbab0f6e9fb 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -4,7 +4,6 @@ mod tests { use crate::recipe::{Execution, ExecutionRun, SubRecipe}; use serde_json::json; - use serde_json::Value; use crate::agents::recipe_tools::param_utils::prepare_command_params; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index e05e00d3ee13..0e0500c490ec 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,7 +1,5 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; -use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_timing, get_status_icon, get_task_name, -}; +use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::json; use std::collections::HashMap; From 445399892a893d245f29f78b8e4987ba3c28f627 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 16:11:29 +1000 Subject: [PATCH 18/43] refactored the code --- .../src/session/task_execution_display.rs | 19 +++-- .../agents/recipe_tools/param_utils/mod.rs | 9 +-- .../agents/recipe_tools/param_utils/tests.rs | 11 ++- .../agents/recipe_tools/sub_recipe_tools.rs | 45 +++++++++--- .../sub_recipe_execution_tool/executor.rs | 7 +- .../agents/sub_recipe_execution_tool/lib.rs | 69 +++++++++++-------- .../task_execution_tracker.rs | 22 ++++-- .../agents/sub_recipe_execution_tool/tasks.rs | 45 ++++++++---- .../agents/sub_recipe_execution_tool/types.rs | 15 ++-- .../sub_recipe_execution_tool/utils/mod.rs | 3 - .../sub_recipe_execution_tool/utils/tests.rs | 67 ++++++------------ .../sub_recipe_execution_tool/workers.rs | 2 +- 12 files changed, 173 insertions(+), 141 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index 38a168241ac9..547c41b09e88 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -1,4 +1,5 @@ use serde_json::Value; +use std::sync::atomic::{AtomicBool, Ordering}; const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; @@ -6,19 +7,17 @@ const CLEAR_TO_EOL: &str = "\x1b[K"; const CLEAR_BELOW: &str = "\x1b[J"; pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; +static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false); + pub fn format_tasks_update(data: &Value) -> String { let mut display = String::new(); - static mut INITIAL_SHOWN: bool = false; - unsafe { - if !INITIAL_SHOWN { - display.push_str(CLEAR_SCREEN); - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); - INITIAL_SHOWN = true; - } else { - display.push_str(MOVE_TO_PROGRESS_LINE); - } + if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); } if let Some(stats) = data.get("stats") { diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs index 08db06c5e25e..e7891c1bb5e7 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -32,10 +32,11 @@ pub fn validate_param_counts( run_params: &[HashMap], params_from_tool_call: &[Value], ) -> Result<()> { - if !run_params.is_empty() - && run_params.len() != params_from_tool_call.len() - && params_from_tool_call.len() > 1 - { + let has_run_params = !run_params.is_empty(); + let multiple_params_from_tool_call = params_from_tool_call.len() > 1; + let count_mismatch = run_params.len() != params_from_tool_call.len(); + + if has_run_params && multiple_params_from_tool_call && count_mismatch { return Err(anyhow::anyhow!( "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", run_params.len(), diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index fcbab0f6e9fb..0a52da64f428 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -1,11 +1,9 @@ -#[cfg(test)] -mod tests { - use std::collections::HashMap; +use std::collections::HashMap; - use crate::recipe::{Execution, ExecutionRun, SubRecipe}; - use serde_json::json; +use crate::recipe::{Execution, ExecutionRun, SubRecipe}; +use serde_json::json; - use crate::agents::recipe_tools::param_utils::prepare_command_params; +use crate::agents::recipe_tools::param_utils::prepare_command_params; fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { @@ -233,4 +231,3 @@ mod tests { } } } -} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 88b100c5e67f..dac0825beda2 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -11,6 +11,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; +const EXECUTION_MODE_PARALLEL: &str = "parallel"; +const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); @@ -37,14 +39,19 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { +fn extract_task_parameters(params: &Value) -> &Vec { let empty_vec = vec![]; - let task_params_array = params + params .get("task_parameters") .and_then(|v| v.as_array()) - .unwrap_or(&empty_vec); - let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; - let tasks = command_params + .unwrap_or(&empty_vec) +} + +fn create_tasks_from_params( + sub_recipe: &SubRecipe, + command_params: &[std::collections::HashMap], +) -> Vec { + command_params .iter() .map(|task_command_param| { let payload = json!({ @@ -61,16 +68,36 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re payload, } }) - .collect::>(); + .collect() +} + +fn get_execution_mode(sub_recipe: &SubRecipe) -> &'static str { let is_parallel = sub_recipe .executions .as_ref() .map(|e| e.parallel) .unwrap_or(false); - let task_execution_payload = json!({ + + if is_parallel { + EXECUTION_MODE_PARALLEL + } else { + EXECUTION_MODE_SEQUENTIAL + } +} + +fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { + json!({ "tasks": tasks, - "execution_mode": if is_parallel { "parallel" } else { "sequential" } - }); + "execution_mode": execution_mode + }) +} + +pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { + let task_params_array = extract_task_parameters(¶ms); + let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; + let tasks = create_tasks_from_params(sub_recipe, &command_params); + let execution_mode = get_execution_mode(sub_recipe); + let task_execution_payload = create_task_execution_payload(tasks, execution_mode); let tasks_json = serde_json::to_string(&task_execution_payload) .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index 4ec7443a40b4..674bd13ae5e9 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -58,7 +58,7 @@ pub async fn execute_tasks_in_parallel( let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); if let Err(e) = send_tasks_to_channel(tasks, task_tx).await { - eprintln!("Execution failed: {}", e); + tracing::error!("Task execution failed: {}", e); return create_error_response(e); } @@ -76,7 +76,7 @@ pub async fn execute_tasks_in_parallel( for handle in worker_handles { if let Err(e) = handle.await { - eprintln!("Worker error: {}", e); + tracing::error!("Worker error: {}", e); } } @@ -180,7 +180,8 @@ async fn collect_results( results } -fn create_error_response(_error: String) -> ExecutionResponse { +fn create_error_response(error: String) -> ExecutionResponse { + tracing::error!("Creating error response: {}", error); ExecutionResponse { status: "failed".to_string(), results: vec![], diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index d3b1f51bb517..e1e20d772969 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -44,46 +44,55 @@ pub async fn execute_tasks( } } -fn handle_response(response: ExecutionResponse) -> Result { - if response.stats.failed > 0 { - let failed_tasks: Vec = response - .results - .iter() - .filter(|r| matches!(r.status, TaskStatus::Failed)) - .map(|r| { - let error_msg = r.error.as_deref().unwrap_or("Unknown error"); - let partial_output = r - .data - .as_ref() - .and_then(|d| d.get("partial_output")) - .and_then(|v| v.as_str()) - .filter(|s| !s.trim().is_empty()) - .unwrap_or("No output captured"); +fn extract_failed_tasks(results: &[TaskResult]) -> Vec { + results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .map(|r| format_failed_task_error(r)) + .collect() +} + +fn format_failed_task_error(result: &TaskResult) -> String { + let error_msg = result.error.as_deref().unwrap_or("Unknown error"); + let partial_output = result + .data + .as_ref() + .and_then(|d| d.get("partial_output")) + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("No output captured"); + + format!( + "Task '{}' ({}): {}\nOutput: {}", + result.task_id, + get_task_description(result), + error_msg, + partial_output + ) +} - format!( - "Task '{}' ({}): {}, \noutput: {}", - r.task_id, - get_task_description(r), - error_msg, - partial_output - ) - }) - .collect(); +fn format_error_summary(failed_count: usize, total_count: usize, failed_tasks: Vec) -> String { + format!( + "{}/{} tasks failed:\n{}", + failed_count, + total_count, + failed_tasks.join("\n") + ) +} - let error_summary = format!( - "{}/{} tasks failed:\n{}", +fn handle_response(response: ExecutionResponse) -> Result { + if response.stats.failed > 0 { + let failed_tasks = extract_failed_tasks(&response.results); + let error_summary = format_error_summary( response.stats.failed, response.stats.total_tasks, - failed_tasks.join("\n") + failed_tasks, ); - return Err(error_summary); } serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e)) } fn get_task_description(result: &TaskResult) -> String { - // We'd need to reconstruct task info from the result or pass it through - // For now, just use the task_id as placeholder format!("ID: {}", result.task_id) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index e3769dd84760..96176fc77930 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -16,6 +16,7 @@ pub enum DisplayMode { } const THROTTLE_INTERVAL_MS: u64 = 250; +const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500; fn format_task_metadata(task_info: &TaskInfo) -> String { if let Some(params) = task_info.task.get_command_parameters() { @@ -110,7 +111,7 @@ impl TaskExecutionTracker { match self.display_mode { DisplayMode::SingleTaskOutput => { // Send raw output data - let subscriber format it - let _ = self + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -123,7 +124,9 @@ impl TaskExecutionTracker { "output": line } })), - })); + })) { + tracing::warn!("Failed to send live output notification: {}", e); + } } DisplayMode::MultipleTasksOutput => { let mut tasks = self.tasks.write().await; @@ -157,7 +160,7 @@ impl TaskExecutionTracker { let task_list: Vec<_> = tasks.values().collect(); let (total, pending, running, completed, failed) = count_by_status(&tasks); - let _ = self + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -195,7 +198,9 @@ impl TaskExecutionTracker { }).collect::>() } })), - })); + })) { + tracing::warn!("Failed to send tasks update notification: {}", e); + } } pub async fn refresh_display(&self) { @@ -244,7 +249,7 @@ impl TaskExecutionTracker { }) .collect(); - let _ = self + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -262,8 +267,11 @@ impl TaskExecutionTracker { "failed_tasks": failed_tasks } })), - })); + })) { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } - sleep(Duration::from_millis(500)).await; + // Brief delay to ensure completion notification is processed + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 428f7bc81f8a..43443804ddd9 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -9,11 +9,13 @@ use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; +const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; + pub async fn process_task( task: &Task, task_execution_tracker: Arc, ) -> TaskResult { - let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(300); + let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_in_seconds); @@ -75,17 +77,19 @@ async fn get_task_result( } fn build_command(task: &Task) -> Result<(Command, String), String> { + let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); + let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { let sub_recipe_name = task .get_sub_recipe_name() - .ok_or("Missing sub_recipe name")?; + .ok_or_else(|| task_error("sub_recipe name"))?; let path = task .get_sub_recipe_path() - .ok_or("Missing sub_recipe path")?; + .ok_or_else(|| task_error("sub_recipe path"))?; let command_parameters = task .get_command_parameters() - .ok_or("Missing command_parameters")?; + .ok_or_else(|| task_error("command_parameters"))?; output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); @@ -101,7 +105,7 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { } else { let text = task .get_text_instruction() - .ok_or("Missing text_instruction")?; + .ok_or_else(|| task_error("text_instruction"))?; let mut cmd = Command::new("goose"); cmd.arg("run").arg("--text").arg(text); cmd @@ -172,13 +176,29 @@ fn spawn_output_reader( .send_live_output(&task_id, &line) .await; } else { - eprintln!("[stderr for {}] {}", output_identifier, line); + tracing::warn!("Task stderr [{}]: {}", output_identifier, line); } } buffer }) } +fn extract_json_from_line(line: &str) -> Option { + let start = line.find('{')?; + let end = line.rfind('}')?; + + if start >= end { + return None; + } + + let potential_json = &line[start..=end]; + if serde_json::from_str::(potential_json).is_ok() { + Some(potential_json.to_string()) + } else { + None + } +} + fn process_output(stdout_output: String) -> Result { let last_line = stdout_output .lines() @@ -186,14 +206,9 @@ fn process_output(stdout_output: String) -> Result { .next_back() .unwrap_or(""); - if let (Some(start), Some(end)) = (last_line.find('{'), last_line.rfind('}')) { - if start < end { - let potential_json = &last_line[start..=end]; - - if serde_json::from_str::(potential_json).is_ok() { - return Ok(Value::String(potential_json.to_string())); - } - } + if let Some(json_string) = extract_json_from_line(last_line) { + Ok(Value::String(json_string)) + } else { + Ok(Value::String(stdout_output)) } - Ok(Value::String(stdout_output)) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 4184eb65ec36..e10c5adb889e 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -16,11 +16,9 @@ pub struct Task { impl Task { pub fn get_sub_recipe(&self) -> Option<&Map> { - if self.task_type == "sub_recipe" { - self.payload.get("sub_recipe").and_then(|sr| sr.as_object()) - } else { - None - } + (self.task_type == "sub_recipe") + .then(|| self.payload.get("sub_recipe")?.as_object()) + .flatten() } pub fn get_command_parameters(&self) -> Option<&Map> { @@ -124,12 +122,15 @@ impl Default for Config { } } +const DEFAULT_MAX_WORKERS: usize = 10; +const DEFAULT_INITIAL_WORKERS: usize = 2; + fn default_max_workers() -> usize { - 10 + DEFAULT_MAX_WORKERS } fn default_initial_workers() -> usize { - 2 + DEFAULT_INITIAL_WORKERS } #[derive(Debug, Serialize)] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index bf806a50d182..e14dbb2f20bb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -10,9 +10,6 @@ pub fn get_task_name(task_info: &TaskInfo) -> &str { .unwrap_or(&task_info.task.id) } -pub fn get_command_parameters(task_info: &TaskInfo) -> Option<&Map> { - task_info.task.get_command_parameters() -} pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index 0e0500c490ec..618026a7f003 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -3,6 +3,17 @@ use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_ use serde_json::json; use std::collections::HashMap; +fn create_task_info_with_defaults(task: Task, status: TaskStatus) -> TaskInfo { + TaskInfo { + task, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + } +} + mod test_get_task_name { use super::*; @@ -20,14 +31,7 @@ mod test_get_task_name { }), }; - let task_info = TaskInfo { - task: sub_recipe_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(sub_recipe_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "my_recipe"); } @@ -41,14 +45,7 @@ mod test_get_task_name { payload: json!({"text_instruction": "do something"}), }; - let task_info = TaskInfo { - task: text_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(text_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "task_2"); } @@ -67,14 +64,7 @@ mod test_get_task_name { }), }; - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "task_3"); } @@ -88,14 +78,7 @@ mod test_get_task_name { payload: json!({}), // missing "sub_recipe" field }; - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "task_4"); } @@ -105,19 +88,13 @@ mod count_by_status { use super::*; fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { - TaskInfo { - task: Task { - id: id.to_string(), - task_type: "test".to_string(), - timeout_in_seconds: None, - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - } + let task = Task { + id: id.to_string(), + task_type: "test".to_string(), + timeout_in_seconds: None, + payload: json!({}), + }; + create_task_info_with_defaults(task, status) } #[test] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index f891595456b5..35e9f6d22219 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -21,7 +21,7 @@ async fn worker_loop(state: Arc, _worker_id: usize) { let result = process_task(&task, state.task_execution_tracker.clone()).await; if let Err(e) = state.result_sender.send(result).await { - eprintln!("Worker failed to send result: {}", e); + tracing::error!("Worker failed to send result: {}", e); break; } } From c14da819a16a105b70f4387a49f00c10ab4722d6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 16:20:32 +1000 Subject: [PATCH 19/43] refactored parsing recipe --- .../agents/recipe_tools/sub_recipe_tools.rs | 6 ++--- .../sub_recipe_execution_tool/utils/mod.rs | 1 - crates/goose/src/recipe/mod.rs | 24 ++++++------------- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index dac0825beda2..db4b3fed6c08 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -39,12 +39,12 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } -fn extract_task_parameters(params: &Value) -> &Vec { - let empty_vec = vec![]; +fn extract_task_parameters(params: &Value) -> Vec { params .get("task_parameters") .and_then(|v| v.as_array()) - .unwrap_or(&empty_vec) + .cloned() + .unwrap_or_default() } fn create_tasks_from_params( diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index e14dbb2f20bb..c2e39b15ade3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,4 +1,3 @@ -use serde_json::{Map, Value}; use std::collections::HashMap; use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 525433221ef2..144c7831b32c 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -281,23 +281,13 @@ impl Recipe { } } pub fn from_content(content: &str) -> Result { - if let Ok(json_value) = serde_json::from_str::(content) { - if let Some(nested_recipe) = json_value.get("recipe") { - Ok(serde_json::from_value(nested_recipe.clone())?) - } else { - Ok(serde_json::from_str(content)?) - } - } else if let Ok(yaml_value) = serde_yaml::from_str::(content) { - if let Some(nested_recipe) = yaml_value.get("recipe") { - Ok(serde_yaml::from_value(nested_recipe.clone())?) - } else { - Ok(serde_yaml::from_str(content)?) - } - } else { - Err(anyhow::anyhow!( - "Unsupported format. Expected JSON or YAML." - )) - } + let yaml_value = serde_yaml::from_str::(content) + .map_err(|_| anyhow::anyhow!("Unsupported format. Expected JSON or YAML."))?; + + // Check if there's a nested "recipe" key, otherwise use the root + let recipe_value = yaml_value.get("recipe").unwrap_or(&yaml_value); + + Ok(serde_yaml::from_value(recipe_value.clone())?) } } From a01d7127bf9c69013ba5f77b291223fe33f5d569 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 17:39:48 +1000 Subject: [PATCH 20/43] created taskExecutionNotificationEvent --- .../src/session/task_execution_display.rs | 319 ++++++-------- .../agents/recipe_tools/param_utils/tests.rs | 415 +++++++++--------- .../agents/recipe_tools/sub_recipe_tools.rs | 2 +- .../agents/sub_recipe_execution_tool/lib.rs | 18 +- .../agents/sub_recipe_execution_tool/mod.rs | 1 + .../notification_events.rs | 204 +++++++++ .../sub_recipe_execute_task_tool.rs | 7 +- .../task_execution_tracker.rs | 129 +++--- .../agents/sub_recipe_execution_tool/tasks.rs | 10 +- .../agents/sub_recipe_execution_tool/types.rs | 19 + .../sub_recipe_execution_tool/utils/mod.rs | 1 - crates/goose/src/recipe/mod.rs | 24 +- 12 files changed, 668 insertions(+), 481 deletions(-) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index 547c41b09e88..f8bad819f4d7 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -1,3 +1,7 @@ +use goose::agents::sub_recipe_execution_tool::lib::TaskStatus; +use goose::agents::sub_recipe_execution_tool::notification_events::{ + TaskExecutionNotificationEvent, TaskInfo, +}; use serde_json::Value; use std::sync::atomic::{AtomicBool, Ordering}; @@ -9,130 +13,13 @@ pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false); -pub fn format_tasks_update(data: &Value) -> String { - let mut display = String::new(); - - if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { - display.push_str(CLEAR_SCREEN); - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); - } else { - display.push_str(MOVE_TO_PROGRESS_LINE); - } - - if let Some(stats) = data.get("stats") { - let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); - let pending = stats.get("pending").and_then(|v| v.as_u64()).unwrap_or(0); - let running = stats.get("running").and_then(|v| v.as_u64()).unwrap_or(0); - let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); - let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); - - display.push_str(&format!( - "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", - total, pending, running, completed, failed - )); - display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); - } - - if let Some(tasks) = data.get("tasks").and_then(|t| t.as_array()) { - let mut sorted_tasks: Vec<_> = tasks.iter().collect(); - sorted_tasks.sort_by_key(|task| task.get("id").and_then(|v| v.as_str()).unwrap_or("")); - - for task in sorted_tasks { - display.push_str(&format_task_from_json(task)); - } - } - - display.push_str(CLEAR_BELOW); - display -} - -fn format_task_from_json(task: &Value) -> String { - let mut task_display = String::new(); - - let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); - let status = task - .get("status") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let task_type = task - .get("task_type") - .and_then(|v| v.as_str()) - .unwrap_or("task"); - let task_name = task.get("task_name").and_then(|v| v.as_str()).unwrap_or(id); - let task_metadata = task - .get("task_metadata") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - let status_icon = match status { - "Pending" => "⏳", - "Running" => "🏃", - "Completed" => "✅", - "Failed" => "❌", - _ => "◯", - }; - - task_display.push_str(&format!( - "{} {} ({}){}\n", - status_icon, task_name, task_type, CLEAR_TO_EOL - )); - - if !task_metadata.is_empty() { - task_display.push_str(&format!( - " 📋 Parameters: {}{}\n", - task_metadata, CLEAR_TO_EOL - )); - } - - if let Some(duration_secs) = task.get("duration_secs").and_then(|v| v.as_f64()) { - task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); - } - - if status == "Running" { - if let Some(current_output) = task.get("current_output").and_then(|v| v.as_str()) { - if !current_output.trim().is_empty() { - let processed_output = process_output_for_display(current_output); - if !processed_output.is_empty() { - task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); - } - } - } - } - - if status == "Completed" { - if let Some(result_data) = task.get("result_data") { - let result_preview = format_result_data_for_display(result_data); - if !result_preview.is_empty() { - task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); - } - } - } - - if status == "Failed" { - if let Some(error) = task.get("error").and_then(|v| v.as_str()) { - let error_preview = truncate_with_ellipsis(error, 80); - task_display.push_str(&format!( - " ⚠️ {}{}\n", - error_preview.replace('\n', " "), - CLEAR_TO_EOL - )); - } - } - - task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); - task_display -} - fn format_result_data_for_display(result_data: &Value) -> String { match result_data { Value::String(s) => strip_ansi_codes(s), Value::Object(obj) => { - // Handle specific result formats if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) { format!("Partial output: {}", partial_output) } else { - // Generic object display serde_json::to_string_pretty(obj).unwrap_or_default() } } @@ -190,84 +77,154 @@ fn strip_ansi_codes(text: &str) -> String { result } -pub fn format_tasks_complete(data: &Value) -> String { - let mut summary = String::new(); - summary.push_str("Execution Complete!\n"); - summary.push_str("═══════════════════════\n"); +pub fn format_task_execution_notification( + data: &Value, +) -> Option<(String, Option, Option)> { + if let Ok(event) = serde_json::from_value::(data.clone()) { + return Some(match event { + TaskExecutionNotificationEvent::LineOutput { output, .. } => ( + format!("{}\n", output), + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ), + TaskExecutionNotificationEvent::TasksUpdate { .. } => { + let formatted_display = format_tasks_update_from_event(&event); + ( + formatted_display, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + TaskExecutionNotificationEvent::TasksComplete { .. } => { + let formatted_summary = format_tasks_complete_from_event(&event); + ( + formatted_summary, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + }); + } + None +} + +fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksUpdate { stats, tasks } = event { + let mut display = String::new(); + + if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); + } + + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + stats.total, stats.pending, stats.running, stats.completed, stats.failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); - if let Some(stats) = data.get("stats") { - let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); - let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); - let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); - let success_rate = stats - .get("success_rate") - .and_then(|v| v.as_f64()) - .unwrap_or(0.0); + let mut sorted_tasks = tasks.clone(); + sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id)); - summary.push_str(&format!("Total Tasks: {}\n", total)); - summary.push_str(&format!("✅ Completed: {}\n", completed)); - summary.push_str(&format!("❌ Failed: {}\n", failed)); - summary.push_str(&format!("📈 Success Rate: {:.1}%\n", success_rate)); + for task in sorted_tasks { + display.push_str(&format_task_from_struct(&task)); + } + + display.push_str(CLEAR_BELOW); + display + } else { + String::new() } +} + +fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + } = event + { + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); + + summary.push_str(&format!("Total Tasks: {}\n", stats.total)); + summary.push_str(&format!("✅ Completed: {}\n", stats.completed)); + summary.push_str(&format!("❌ Failed: {}\n", stats.failed)); + summary.push_str(&format!("📈 Success Rate: {:.1}%\n", stats.success_rate)); - if let Some(failed_tasks) = data.get("failed_tasks").and_then(|t| t.as_array()) { if !failed_tasks.is_empty() { summary.push_str("\n❌ Failed Tasks:\n"); for task in failed_tasks { - let name = task - .get("name") - .and_then(|v| v.as_str()) - .unwrap_or("Unknown"); - summary.push_str(&format!(" • {}\n", name)); - if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + summary.push_str(&format!(" • {}\n", task.name)); + if let Some(error) = &task.error { summary.push_str(&format!(" Error: {}\n", error)); } } } - } - summary.push_str("\n📝 Generating summary...\n"); - summary + summary.push_str("\n📝 Generating summary...\n"); + summary + } else { + String::new() + } } -pub fn format_task_execution_notification( - data: &Value, -) -> Option<(String, Option, Option)> { - if let Value::Object(o) = data { - if o.get("type").and_then(|t| t.as_str()) == Some(TASK_EXECUTION_NOTIFICATION_TYPE) { - return Some(match o.get("subtype").and_then(|t| t.as_str()) { - Some("line_output") => { - if let Some(Value::String(line_output)) = o.get("output") { - ( - format!("{}\n", line_output), - None, - Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), - ) - } else { - (data.to_string(), None, None) - } - } - Some("tasks_update") => { - let data_value = Value::Object(o.clone()); - let formatted_display = format_tasks_update(&data_value); - ( - formatted_display, - None, - Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), - ) - } - Some("tasks_complete") => { - let data_value = Value::Object(o.clone()); - let formatted_summary = format_tasks_complete(&data_value); - ( - formatted_summary, - None, - Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), - ) - } - _ => (data.to_string(), None, None), - }); +fn format_task_from_struct(task: &TaskInfo) -> String { + let mut task_display = String::new(); + + let status_icon = match task.status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + }; + + task_display.push_str(&format!( + "{} {} ({}){}\n", + status_icon, task.task_name, task.task_type, CLEAR_TO_EOL + )); + + if !task.task_metadata.is_empty() { + task_display.push_str(&format!( + " 📋 Parameters: {}{}\n", + task.task_metadata, CLEAR_TO_EOL + )); + } + + if let Some(duration_secs) = task.duration_secs { + task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); + } + + if matches!(task.status, TaskStatus::Running) && !task.current_output.trim().is_empty() { + let processed_output = process_output_for_display(&task.current_output); + if !processed_output.is_empty() { + task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); } } - None + + if matches!(task.status, TaskStatus::Completed) { + if let Some(result_data) = &task.result_data { + let result_preview = format_result_data_for_display(result_data); + if !result_preview.is_empty() { + task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); + } + } + } + + if matches!(task.status, TaskStatus::Failed) { + if let Some(error) = &task.error { + let error_preview = truncate_with_ellipsis(error, 80); + task_display.push_str(&format!( + " ⚠️ {}{}\n", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + )); + } + } + + task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); + task_display } diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index 0a52da64f428..9cbec0c7b828 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -5,229 +5,224 @@ use serde_json::json; use crate::agents::recipe_tools::param_utils::prepare_command_params; - fn setup_default_sub_recipe() -> SubRecipe { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - timeout_in_seconds: None, - values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), - executions: None, - }; - sub_recipe +fn setup_default_sub_recipe() -> SubRecipe { + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: "test_sub_recipe.yaml".to_string(), + timeout_in_seconds: None, + values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + executions: None, + }; + sub_recipe +} + +fn create_execution_values(key: &str, values: Vec) -> Execution { + let runs = values + .iter() + .map(|value| ExecutionRun { + values: Some(HashMap::from([(key.to_string(), value.to_string())])), + }) + .collect(); + Execution { + parallel: true, + runs: Some(runs), } +} - fn create_execution_values(key: &str, values: Vec) -> Execution { - let runs = values - .iter() - .map(|value| ExecutionRun { - values: Some(HashMap::from([(key.to_string(), value.to_string())])), - }) - .collect(); - Execution { - parallel: true, - runs: Some(runs), +mod prepare_command_params_tests { + use super::*; + + mod without_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result.len(), 0); } } - mod prepare_command_params_tests { + mod with_execution_runs { use super::*; - mod without_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "value2".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_value_override_passed_param_value() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "different_value".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = + Some(create_execution_values("key2", vec!["value2".to_string()])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ ("key1".to_string(), "value1".to_string()), ("key2".to_string(), "value2".to_string()), - ])); + ("key3".to_string(), "value3".to_string()) + ]),], + result + ); + } - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + sub_recipe.executions = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_return_command_param_when_all_from_values_in_sub_recipe() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_empty_command_param() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!(result.len(), 0); - } + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); } - mod with_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = - Some(create_execution_values("key2", vec!["value2".to_string()])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ + #[test] + fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( + ) { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + ),]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ("key3".to_string(), "value3".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_all_values_from_tool_call_parameters() { - let parameter_array = vec![ - json!(HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()) - ])), - json!(HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()) - ])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - sub_recipe.executions = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ]), - ], - result - ); - } - - #[test] - fn test_return_command_param_when_all_from_values_in_sub_recipe() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( - ) { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - ),]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( - ) { - let parameter_array = vec![ - json!(HashMap::from([("key3".to_string(), "value3".to_string())])), - json!(HashMap::from([("key4".to_string(), "value4".to_string())])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array); - - assert!(result.is_err()); - } + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([("key3".to_string(), "value3".to_string())])), + json!(HashMap::from([("key4".to_string(), "value4".to_string())])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array); + + assert!(result.is_err()); } } +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index db4b3fed6c08..fd0edaeaecd1 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -77,7 +77,7 @@ fn get_execution_mode(sub_recipe: &SubRecipe) -> &'static str { .as_ref() .map(|e| e.parallel) .unwrap_or(false); - + if is_parallel { EXECUTION_MODE_PARALLEL } else { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index e1e20d772969..4cded2c71000 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -2,7 +2,8 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, + Config, ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, + TaskStatus, }; use mcp_core::protocol::JsonRpcMessage; @@ -11,7 +12,7 @@ use tokio::sync::mpsc; pub async fn execute_tasks( input: Value, - execution_mode: &str, + execution_mode: ExecutionMode, notifier: mpsc::Sender, ) -> Result { let tasks: Vec = @@ -27,7 +28,7 @@ pub async fn execute_tasks( let task_count = tasks.len(); match execution_mode { - "sequential" => { + ExecutionMode::Sequential => { if task_count == 1 { let response = execute_single_task(&tasks[0], notifier).await; handle_response(response) @@ -35,12 +36,11 @@ pub async fn execute_tasks( Err("Sequential execution mode requires exactly one task".to_string()) } } - "parallel" => { + ExecutionMode::Parallel => { let response: ExecutionResponse = execute_tasks_in_parallel(tasks, config, notifier).await; handle_response(response) } - _ => Err("Invalid execution mode".to_string()), } } @@ -48,7 +48,7 @@ fn extract_failed_tasks(results: &[TaskResult]) -> Vec { results .iter() .filter(|r| matches!(r.status, TaskStatus::Failed)) - .map(|r| format_failed_task_error(r)) + .map(format_failed_task_error) .collect() } @@ -71,7 +71,11 @@ fn format_failed_task_error(result: &TaskResult) -> String { ) } -fn format_error_summary(failed_count: usize, total_count: usize, failed_tasks: Vec) -> String { +fn format_error_summary( + failed_count: usize, + total_count: usize, + failed_tasks: Vec, +) -> String { format!( "{}/{} tasks failed:\n{}", failed_count, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index 03568dc53197..267caab8f115 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,5 +1,6 @@ mod executor; pub mod lib; +pub mod notification_events; pub mod sub_recipe_execute_task_tool; mod task_execution_tracker; mod tasks; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs new file mode 100644 index 000000000000..97a4576b2d99 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs @@ -0,0 +1,204 @@ +use crate::agents::sub_recipe_execution_tool::types::TaskStatus; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype")] +pub enum TaskExecutionNotificationEvent { + #[serde(rename = "line_output")] + LineOutput { task_id: String, output: String }, + #[serde(rename = "tasks_update")] + TasksUpdate { + stats: TaskExecutionStats, + tasks: Vec, + }, + #[serde(rename = "tasks_complete")] + TasksComplete { + stats: TaskCompletionStats, + failed_tasks: Vec, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskExecutionStats { + pub total: usize, + pub pending: usize, + pub running: usize, + pub completed: usize, + pub failed: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskCompletionStats { + pub total: usize, + pub completed: usize, + pub failed: usize, + pub success_rate: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskInfo { + pub id: String, + pub status: TaskStatus, + pub duration_secs: Option, + pub current_output: String, + pub task_type: String, + pub task_name: String, + pub task_metadata: String, + pub error: Option, + pub result_data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FailedTaskInfo { + pub id: String, + pub name: String, + pub error: Option, +} + +impl TaskExecutionNotificationEvent { + pub fn line_output(task_id: String, output: String) -> Self { + Self::LineOutput { task_id, output } + } + + pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec) -> Self { + Self::TasksUpdate { stats, tasks } + } + + pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec) -> Self { + Self::TasksComplete { + stats, + failed_tasks, + } + } + + /// Convert event to JSON format for MCP notification + pub fn to_notification_data(&self) -> serde_json::Value { + let mut event_data = serde_json::to_value(self).expect("Failed to serialize event"); + + // Add the type field at the root level + if let serde_json::Value::Object(ref mut map) = event_data { + map.insert( + "type".to_string(), + serde_json::Value::String("task_execution".to_string()), + ); + } + + event_data + } +} + +impl TaskExecutionStats { + pub fn new( + total: usize, + pending: usize, + running: usize, + completed: usize, + failed: usize, + ) -> Self { + Self { + total, + pending, + running, + completed, + failed, + } + } +} + +impl TaskCompletionStats { + pub fn new(total: usize, completed: usize, failed: usize) -> Self { + let success_rate = if total > 0 { + (completed as f64 / total as f64) * 100.0 + } else { + 0.0 + }; + + Self { + total, + completed, + failed, + success_rate, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_line_output_event_serialization() { + let event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Hello World".to_string(), + ); + + let notification_data = event.to_notification_data(); + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "line_output"); + assert_eq!(notification_data["task_id"], "task-1"); + assert_eq!(notification_data["output"], "Hello World"); + } + + #[test] + fn test_tasks_update_event_serialization() { + let stats = TaskExecutionStats::new(5, 2, 1, 1, 1); + let tasks = vec![TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }]; + + let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks); + let notification_data = event.to_notification_data(); + + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "tasks_update"); + assert_eq!(notification_data["stats"]["total"], 5); + assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1); + } + + #[test] + fn test_event_roundtrip_serialization() { + let original_event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Test output".to_string(), + ); + + // Serialize to JSON + let json_data = original_event.to_notification_data(); + + // Deserialize back to event (excluding the type field) + let mut event_data = json_data.clone(); + if let serde_json::Value::Object(ref mut map) = event_data { + map.remove("type"); + } + + let deserialized_event: TaskExecutionNotificationEvent = + serde_json::from_value(event_data).expect("Failed to deserialize"); + + match (original_event, deserialized_event) { + ( + TaskExecutionNotificationEvent::LineOutput { + task_id: id1, + output: out1, + }, + TaskExecutionNotificationEvent::LineOutput { + task_id: id2, + output: out2, + }, + ) => { + assert_eq!(id1, id2); + assert_eq!(out1, out2); + } + _ => panic!("Event types don't match after roundtrip"), + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index f1ca709cb318..80a537393e5b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -2,7 +2,8 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, + sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::types::ExecutionMode, + tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; @@ -128,8 +129,8 @@ pub async fn run_tasks(execute_data: Value) -> ToolCallResult { let execute_data_clone = execute_data.clone(); let execution_mode = execute_data_clone .get("execution_mode") - .and_then(|v| v.as_str()) - .unwrap_or("sequential"); + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or_default(); match execute_tasks(execute_data, execution_mode, notification_tx).await { Ok(result) => { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 96176fc77930..639e4580b843 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -5,6 +5,10 @@ use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; +use crate::agents::sub_recipe_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, + TaskInfo as EventTaskInfo, +}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::Value; @@ -110,21 +114,21 @@ impl TaskExecutionTracker { pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { - // Send raw output data - let subscriber format it - if let Err(e) = self - .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "task_execution", - "subtype": "line_output", - "task_id": task_id, - "output": line - } - })), - })) { + let event = TaskExecutionNotificationEvent::line_output( + task_id.to_string(), + line.to_string(), + ); + + if let Err(e) = + self.notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { tracing::warn!("Failed to send live output notification: {}", e); } } @@ -160,45 +164,44 @@ impl TaskExecutionTracker { let task_list: Vec<_> = tasks.values().collect(); let (total, pending, running, completed, failed) = count_by_status(&tasks); + let stats = TaskExecutionStats::new(total, pending, running, completed, failed); + + let event_tasks: Vec = task_list + .iter() + .map(|task_info| { + let now = Instant::now(); + EventTaskInfo { + id: task_info.task.id.clone(), + status: task_info.status.clone(), + duration_secs: task_info.start_time.map(|start| { + if let Some(end) = task_info.end_time { + end.duration_since(start).as_secs_f64() + } else { + now.duration_since(start).as_secs_f64() + } + }), + current_output: task_info.current_output.clone(), + task_type: task_info.task.task_type.clone(), + task_name: get_task_name(task_info).to_string(), + task_metadata: format_task_metadata(task_info), + error: task_info.error().cloned(), + result_data: task_info.data().cloned(), + } + }) + .collect(); + + let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks); + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: "notifications/message".to_string(), params: Some(json!({ - "data": { - "type": "task_execution", - "subtype": "tasks_update", - "stats": { - "total": total, - "pending": pending, - "running": running, - "completed": completed, - "failed": failed - }, - "tasks": task_list.iter().map(|task_info| { - let now = Instant::now(); - json!({ - "id": task_info.task.id, - "status": task_info.status, - "duration_secs": task_info.start_time.map(|start| { - if let Some(end) = task_info.end_time { - end.duration_since(start).as_secs_f64() - } else { - now.duration_since(start).as_secs_f64() - } - }), - "current_output": task_info.current_output, - "task_type": task_info.task.task_type, - "task_name": get_task_name(task_info), - "task_metadata": format_task_metadata(task_info), - "error": task_info.error(), - "result_data": task_info.data() - }) - }).collect::>() - } + "data": event.to_notification_data() })), - })) { + })) + { tracing::warn!("Failed to send tasks update notification: {}", e); } } @@ -236,38 +239,30 @@ impl TaskExecutionTracker { let tasks = self.tasks.read().await; let (total, _, _, completed, failed) = count_by_status(&tasks); - // Send structured summary data only - let failed_tasks: Vec<_> = tasks + let stats = TaskCompletionStats::new(total, completed, failed); + + let failed_tasks: Vec = tasks .values() .filter(|task_info| matches!(task_info.status, TaskStatus::Failed)) - .map(|task_info| { - json!({ - "id": task_info.task.id, - "name": get_task_name(task_info), - "error": task_info.error() - }) + .map(|task_info| FailedTaskInfo { + id: task_info.task.id.clone(), + name: get_task_name(task_info).to_string(), + error: task_info.error().cloned(), }) .collect(); + let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks); + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: "notifications/message".to_string(), params: Some(json!({ - "data": { - "type": "task_execution", - "subtype": "tasks_complete", - "stats": { - "total": total, - "completed": completed, - "failed": failed, - "success_rate": (completed as f64 / total as f64) * 100.0 - }, - "failed_tasks": failed_tasks - } + "data": event.to_notification_data() })), - })) { + })) + { tracing::warn!("Failed to send tasks complete notification: {}", e); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 43443804ddd9..60d022816680 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -15,7 +15,9 @@ pub async fn process_task( task: &Task, task_execution_tracker: Arc, ) -> TaskResult { - let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); + let timeout_in_seconds = task + .timeout_in_seconds + .unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_in_seconds); @@ -78,7 +80,7 @@ async fn get_task_result( fn build_command(task: &Task) -> Result<(Command, String), String> { let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); - + let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { let sub_recipe_name = task @@ -186,11 +188,11 @@ fn spawn_output_reader( fn extract_json_from_line(line: &str) -> Option { let start = line.find('{')?; let end = line.rfind('}')?; - + if start >= end { return None; } - + let potential_json = &line[start..=end]; if serde_json::from_str::(potential_json).is_ok() { Some(potential_json.to_string()) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index e10c5adb889e..e558cee1f08b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -6,6 +6,14 @@ use tokio::sync::mpsc; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "lowercase")] +pub enum ExecutionMode { + #[default] + Sequential, + Parallel, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { pub id: String, @@ -68,6 +76,17 @@ pub enum TaskStatus { Failed, } +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "Pending"), + TaskStatus::Running => write!(f, "Running"), + TaskStatus::Completed => write!(f, "Completed"), + TaskStatus::Failed => write!(f, "Failed"), + } + } +} + #[derive(Debug, Clone)] pub struct TaskInfo { pub task: Task, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index c2e39b15ade3..b86a69a8fcfe 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -9,7 +9,6 @@ pub fn get_task_name(task_info: &TaskInfo) -> &str { .unwrap_or(&task_info.task.id) } - pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); let (pending, running, completed, failed) = tasks.values().fold( diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 144c7831b32c..525433221ef2 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -281,13 +281,23 @@ impl Recipe { } } pub fn from_content(content: &str) -> Result { - let yaml_value = serde_yaml::from_str::(content) - .map_err(|_| anyhow::anyhow!("Unsupported format. Expected JSON or YAML."))?; - - // Check if there's a nested "recipe" key, otherwise use the root - let recipe_value = yaml_value.get("recipe").unwrap_or(&yaml_value); - - Ok(serde_yaml::from_value(recipe_value.clone())?) + if let Ok(json_value) = serde_json::from_str::(content) { + if let Some(nested_recipe) = json_value.get("recipe") { + Ok(serde_json::from_value(nested_recipe.clone())?) + } else { + Ok(serde_json::from_str(content)?) + } + } else if let Ok(yaml_value) = serde_yaml::from_str::(content) { + if let Some(nested_recipe) = yaml_value.get("recipe") { + Ok(serde_yaml::from_value(nested_recipe.clone())?) + } else { + Ok(serde_yaml::from_str(content)?) + } + } else { + Err(anyhow::anyhow!( + "Unsupported format. Expected JSON or YAML." + )) + } } } From 41c46a65073235b8153dede3d890218a63049d4a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 18:43:33 +1000 Subject: [PATCH 21/43] added tests --- .../mod.rs} | 29 +- .../session/task_execution_display/tests.rs | 337 ++++++++++++++++++ .../{executor.rs => executor/mod.rs} | 3 + .../executor/tests.rs | 100 ++++++ .../{lib.rs => lib/mod.rs} | 3 + .../sub_recipe_execution_tool/lib/tests.rs | 216 +++++++++++ 6 files changed, 680 insertions(+), 8 deletions(-) rename crates/goose-cli/src/session/{task_execution_display.rs => task_execution_display/mod.rs} (89%) create mode 100644 crates/goose-cli/src/session/task_execution_display/tests.rs rename crates/goose/src/agents/sub_recipe_execution_tool/{executor.rs => executor/mod.rs} (99%) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs rename crates/goose/src/agents/sub_recipe_execution_tool/{lib.rs => lib/mod.rs} (99%) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs similarity index 89% rename from crates/goose-cli/src/session/task_execution_display.rs rename to crates/goose-cli/src/session/task_execution_display/mod.rs index f8bad819f4d7..96d37b76d483 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -5,6 +5,9 @@ use goose::agents::sub_recipe_execution_tool::notification_events::{ use serde_json::Value; use std::sync::atomic::{AtomicBool, Ordering}; +#[cfg(test)] +mod tests; + const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; const CLEAR_TO_EOL: &str = "\x1b[K"; @@ -60,14 +63,24 @@ fn strip_ansi_codes(text: &str) -> String { while let Some(ch) = chars.next() { if ch == '\x1b' { - if chars.next() == Some('[') { - loop { - match chars.next() { - Some(c) if c.is_ascii_alphabetic() => break, - Some(_) => continue, - None => break, + if let Some(next_ch) = chars.next() { + if next_ch == '[' { + // This is an ANSI escape sequence, consume until alphabetic character + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } } + } else { + // Not an ANSI sequence, keep both characters + result.push(ch); + result.push(next_ch); } + } else { + // End of string after \x1b + result.push(ch); } } else { result.push(ch); @@ -130,7 +143,7 @@ fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> Str sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id)); for task in sorted_tasks { - display.push_str(&format_task_from_struct(&task)); + display.push_str(&format_task_display(&task)); } display.push_str(CLEAR_BELOW); @@ -172,7 +185,7 @@ fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> S } } -fn format_task_from_struct(task: &TaskInfo) -> String { +fn format_task_display(task: &TaskInfo) -> String { let mut task_display = String::new(); let status_icon = match task.status { diff --git a/crates/goose-cli/src/session/task_execution_display/tests.rs b/crates/goose-cli/src/session/task_execution_display/tests.rs new file mode 100644 index 000000000000..fb53285080d3 --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display/tests.rs @@ -0,0 +1,337 @@ +use super::*; +use goose::agents::sub_recipe_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionStats, +}; +use serde_json::json; + +#[test] +fn test_strip_ansi_codes() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("normal\x1b[33myellow\x1b[0mnormal"), + "normalyellownormal" + ); + assert_eq!(strip_ansi_codes("\x1bhello"), "\x1bhello"); + assert_eq!(strip_ansi_codes("hello\x1b"), "hello\x1b"); + assert_eq!(strip_ansi_codes(""), ""); +} + +#[test] +fn test_truncate_with_ellipsis() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + assert_eq!(truncate_with_ellipsis("hello world", 8), "hello..."); + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hello", 2), "..."); + assert_eq!(truncate_with_ellipsis("hello", 1), "..."); + assert_eq!(truncate_with_ellipsis("", 5), ""); +} + +#[test] +fn test_process_output_for_display() { + assert_eq!(process_output_for_display("hello world"), "hello world"); + assert_eq!( + process_output_for_display("line1\nline2"), + "line1 ... line2" + ); + + let input = "line1\nline2\nline3\nline4"; + let result = process_output_for_display(input); + assert_eq!(result, "line3 ... line4"); + + let long_line = "a".repeat(150); + let result = process_output_for_display(&long_line); + assert!(result.len() <= 100); + assert!(result.ends_with("...")); + + let ansi_output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_for_display(ansi_output); + assert_eq!(result, "red line 1 ... green line 2"); + + assert_eq!(process_output_for_display(""), ""); +} + +#[test] +fn test_format_result_data_for_display() { + let string_val = json!("hello world"); + assert_eq!(format_result_data_for_display(&string_val), "hello world"); + + let ansi_string = json!("\x1b[31mred text\x1b[0m"); + assert_eq!(format_result_data_for_display(&ansi_string), "red text"); + + assert_eq!(format_result_data_for_display(&json!(true)), "true"); + assert_eq!(format_result_data_for_display(&json!(false)), "false"); + assert_eq!(format_result_data_for_display(&json!(42)), "42"); + assert_eq!(format_result_data_for_display(&json!(3.14)), "3.14"); + assert_eq!(format_result_data_for_display(&json!(null)), "null"); + + let partial_obj = json!({ + "partial_output": "some output", + "other_field": "ignored" + }); + assert_eq!( + format_result_data_for_display(&partial_obj), + "Partial output: some output" + ); + + let obj = json!({"key": "value", "num": 42}); + let result = format_result_data_for_display(&obj); + assert!(result.contains("key")); + assert!(result.contains("value")); + + let arr = json!([1, 2, 3]); + let result = format_result_data_for_display(&arr); + assert!(result.contains("1")); + assert!(result.contains("2")); + assert!(result.contains("3")); +} + +#[test] +fn test_format_task_execution_notification_line_output() { + let _event = TaskExecutionNotificationEvent::LineOutput { + task_id: "task-1".to_string(), + output: "Hello World".to_string(), + }; + + let data = json!({ + "subtype": "line_output", + "task_id": "task-1", + "output": "Hello World" + }); + + let result = format_task_execution_notification(&data); + assert!(result.is_some()); + + let (formatted, second, third) = result.unwrap(); + assert_eq!(formatted, "Hello World\n"); + assert_eq!(second, None); + assert_eq!(third, Some("task_execution".to_string())); +} + +#[test] +fn test_format_task_execution_notification_invalid_data() { + let invalid_data = json!({ + "invalid": "structure" + }); + + let result = format_task_execution_notification(&invalid_data); + assert_eq!(result, None); + + let incomplete_data = json!({ + "subtype": "line_output" + }); + + let result = format_task_execution_notification(&incomplete_data); + assert_eq!(result, None); +} + +#[test] +fn test_format_tasks_update_from_event() { + INITIAL_SHOWN.store(false, Ordering::SeqCst); + + let stats = TaskExecutionStats::new(3, 1, 1, 1, 0); + let tasks = vec![ + TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }, + TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(2.3), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "another-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"result": "success"})), + }, + ]; + + let event = TaskExecutionNotificationEvent::TasksUpdate { stats, tasks }; + let result = format_tasks_update_from_event(&event); + + assert!(result.contains("🎯 Task Execution Dashboard")); + assert!(result.contains("═══════════════════════════")); + assert!(result.contains("📊 Progress: 3 total")); + assert!(result.contains("⏳ 1 pending")); + assert!(result.contains("🏃 1 running")); + assert!(result.contains("✅ 1 completed")); + assert!(result.contains("❌ 0 failed")); + assert!(result.contains("🏃 test-task")); + assert!(result.contains("✅ another-task")); + assert!(result.contains("📋 Parameters: param=value")); + assert!(result.contains("⏱️ 1.5s")); + assert!(result.contains("💬 Processing...")); + + let result2 = format_tasks_update_from_event(&event); + assert!(!result2.contains("🎯 Task Execution Dashboard")); + assert!(result2.contains(MOVE_TO_PROGRESS_LINE)); +} + +#[test] +fn test_format_tasks_complete_from_event() { + let stats = TaskCompletionStats::new(5, 4, 1); + let failed_tasks = vec![FailedTaskInfo { + id: "task-3".to_string(), + name: "failed-task".to_string(), + error: Some("Connection timeout".to_string()), + }]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(result.contains("Execution Complete!")); + assert!(result.contains("═══════════════════════")); + assert!(result.contains("Total Tasks: 5")); + assert!(result.contains("✅ Completed: 4")); + assert!(result.contains("❌ Failed: 1")); + assert!(result.contains("📈 Success Rate: 80.0%")); + assert!(result.contains("❌ Failed Tasks:")); + assert!(result.contains("• failed-task")); + assert!(result.contains("Error: Connection timeout")); + assert!(result.contains("📝 Generating summary...")); +} + +#[test] +fn test_format_tasks_complete_from_event_no_failures() { + let stats = TaskCompletionStats::new(3, 3, 0); + let failed_tasks = vec![]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(!result.contains("❌ Failed Tasks:")); + assert!(result.contains("📈 Success Rate: 100.0%")); + assert!(result.contains("❌ Failed: 0")); +} + +#[test] +fn test_format_task_display_running() { + let task = TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing data...\nAlmost done...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "data-processor".to_string(), + task_metadata: "input=file.txt,output=result.json".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("🏃 data-processor (sub_recipe)")); + assert!(result.contains("📋 Parameters: input=file.txt,output=result.json")); + assert!(result.contains("⏱️ 1.5s")); + assert!(result.contains("💬 Processing data... ... Almost done...")); +} + +#[test] +fn test_format_task_display_completed() { + let task = TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(3.2), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "analyzer".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"status": "success", "count": 42})), + }; + + let result = format_task_display(&task); + + assert!(result.contains("✅ analyzer (text_instruction)")); + assert!(result.contains("⏱️ 3.2s")); + assert!(!result.contains("📋 Parameters")); + assert!(result.contains("📄")); +} + +#[test] +fn test_format_task_display_failed() { + let task = TaskInfo { + id: "task-3".to_string(), + status: TaskStatus::Failed, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "failing-task".to_string(), + task_metadata: "".to_string(), + error: Some( + "Network connection failed after multiple retries. The server is unreachable." + .to_string(), + ), + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("❌ failing-task (sub_recipe)")); + assert!(!result.contains("⏱️")); + assert!(result.contains("⚠️")); + assert!(result.contains("Network connection failed after multiple retries")); +} + +#[test] +fn test_format_task_display_pending() { + let task = TaskInfo { + id: "task-4".to_string(), + status: TaskStatus::Pending, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "waiting-task".to_string(), + task_metadata: "priority=high".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("⏳ waiting-task (sub_recipe)")); + assert!(result.contains("📋 Parameters: priority=high")); + assert!(!result.contains("⏱️")); + assert!(!result.contains("💬")); + assert!(!result.contains("📄")); + assert!(!result.contains("⚠️")); +} + +#[test] +fn test_format_task_display_empty_current_output() { + let task = TaskInfo { + id: "task-5".to_string(), + status: TaskStatus::Running, + duration_secs: Some(0.5), + current_output: " \n\t \n ".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "quiet-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(!result.contains("💬")); +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs similarity index 99% rename from crates/goose/src/agents/sub_recipe_execution_tool/executor.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index 674bd13ae5e9..c6183b7c137b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -13,6 +13,9 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; +#[cfg(test)] +mod tests; + const EXECUTION_STATUS_COMPLETED: &str = "completed"; pub async fn execute_single_task( diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs new file mode 100644 index 000000000000..76385b87ef37 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs @@ -0,0 +1,100 @@ +use super::{calculate_stats, create_empty_response, create_error_response}; +use crate::agents::sub_recipe_execution_tool::lib::{TaskResult, TaskStatus}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus) -> TaskResult { + let is_failed = matches!(status, TaskStatus::Failed); + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"output": "test output"})), + error: if is_failed { + Some("Test error".to_string()) + } else { + None + }, + } +} + +#[test] +fn test_calculate_stats() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + create_test_task_result("task3", TaskStatus::Failed), + create_test_task_result("task4", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 1500); + + assert_eq!(stats.total_tasks, 4); + assert_eq!(stats.completed, 3); + assert_eq!(stats.failed, 1); + assert_eq!(stats.execution_time_ms, 1500); +} + +#[test] +fn test_calculate_stats_empty_results() { + let results = vec![]; + let stats = calculate_stats(&results, 0); + + assert_eq!(stats.total_tasks, 0); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 0); +} + +#[test] +fn test_calculate_stats_all_completed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 800); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 2); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 800); +} + +#[test] +fn test_calculate_stats_all_failed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed), + create_test_task_result("task2", TaskStatus::Failed), + ]; + + let stats = calculate_stats(&results, 1200); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 2); + assert_eq!(stats.execution_time_ms, 1200); +} + +#[test] +fn test_create_empty_response() { + let response = create_empty_response(); + + assert_eq!(response.status, "completed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 0); + assert_eq!(response.stats.execution_time_ms, 0); +} + +#[test] +fn test_create_error_response() { + let error_msg = "Test error message"; + let response = create_error_response(error_msg.to_string()); + + assert_eq!(response.status, "failed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 1); + assert_eq!(response.stats.execution_time_ms, 0); +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs similarity index 99% rename from crates/goose/src/agents/sub_recipe_execution_tool/lib.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 4cded2c71000..0515bf37bb44 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -6,6 +6,9 @@ pub use crate::agents::sub_recipe_execution_tool::types::{ TaskStatus, }; +#[cfg(test)] +mod tests; + use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; use tokio::sync::mpsc; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs new file mode 100644 index 000000000000..957b11274279 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs @@ -0,0 +1,216 @@ +use super::{ + extract_failed_tasks, format_error_summary, format_failed_task_error, get_task_description, + handle_response, +}; +use crate::agents::sub_recipe_execution_tool::lib::{ + ExecutionResponse, ExecutionStats, TaskResult, TaskStatus, +}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus, error: Option) -> TaskResult { + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"partial_output": "test output"})), + error, + } +} + +fn create_test_execution_response( + results: Vec, + failed_count: usize, +) -> ExecutionResponse { + ExecutionResponse { + status: "completed".to_string(), + results: results.clone(), + stats: ExecutionStats { + total_tasks: results.len(), + completed: results.len() - failed_count, + failed: failed_count, + execution_time_ms: 1000, + }, + } +} + +#[test] +fn test_extract_failed_tasks() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result( + "task2", + TaskStatus::Failed, + Some("Error message".to_string()), + ), + create_test_task_result("task3", TaskStatus::Completed, None), + create_test_task_result( + "task4", + TaskStatus::Failed, + Some("Another error".to_string()), + ), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 2); + assert!(failed_tasks[0].contains("task2")); + assert!(failed_tasks[0].contains("Error message")); + assert!(failed_tasks[1].contains("task4")); + assert!(failed_tasks[1].contains("Another error")); +} + +#[test] +fn test_extract_failed_tasks_empty() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 0); +} + +#[test] +fn test_format_failed_task_error_with_error_message() { + let result = create_test_task_result( + "task1", + TaskStatus::Failed, + Some("Test error message".to_string()), + ); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task1")); + assert!(formatted.contains("Test error message")); + assert!(formatted.contains("test output")); + assert!(formatted.contains("ID: task1")); +} + +#[test] +fn test_format_failed_task_error_without_error_message() { + let result = create_test_task_result("task2", TaskStatus::Failed, None); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task2")); + assert!(formatted.contains("Unknown error")); + assert!(formatted.contains("test output")); +} + +#[test] +fn test_format_failed_task_error_empty_partial_output() { + let mut result = + create_test_task_result("task3", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({"partial_output": ""})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_partial_output() { + let mut result = + create_test_task_result("task4", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_data() { + let mut result = + create_test_task_result("task5", TaskStatus::Failed, Some("Error".to_string())); + result.data = None; + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_error_summary() { + let failed_tasks = vec![ + "Task 'task1': Error 1\nOutput: output1".to_string(), + "Task 'task2': Error 2\nOutput: output2".to_string(), + ]; + + let summary = format_error_summary(2, 5, failed_tasks); + + assert_eq!(summary, "2/5 tasks failed:\nTask 'task1': Error 1\nOutput: output1\nTask 'task2': Error 2\nOutput: output2"); +} + +#[test] +fn test_format_error_summary_single_failure() { + let failed_tasks = vec!["Task 'task1': Error\nOutput: output".to_string()]; + + let summary = format_error_summary(1, 3, failed_tasks); + + assert_eq!( + summary, + "1/3 tasks failed:\nTask 'task1': Error\nOutput: output" + ); +} + +#[test] +fn test_handle_response_success() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + let response = create_test_execution_response(results, 0); + + let result = handle_response(response); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value["status"], "completed"); + assert_eq!(value["stats"]["failed"], 0); +} + +#[test] +fn test_handle_response_with_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Failed, Some("Test error".to_string())), + ]; + let response = create_test_execution_response(results, 1); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("1/2 tasks failed")); + assert!(error.contains("task2")); + assert!(error.contains("Test error")); +} + +#[test] +fn test_handle_response_all_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed, Some("Error 1".to_string())), + create_test_task_result("task2", TaskStatus::Failed, Some("Error 2".to_string())), + ]; + let response = create_test_execution_response(results, 2); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("2/2 tasks failed")); + assert!(error.contains("task1")); + assert!(error.contains("task2")); + assert!(error.contains("Error 1")); + assert!(error.contains("Error 2")); +} + +#[test] +fn test_get_task_description() { + let result = create_test_task_result("test_task_123", TaskStatus::Completed, None); + + let description = get_task_description(&result); + + assert_eq!(description, "ID: test_task_123"); +} From 2c0f24c450f980aba4559522185f76c6a2ffe846 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 19:32:59 +1000 Subject: [PATCH 22/43] revert some unexpected deletion --- crates/goose-cli/src/session/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 5ad837991837..61e434568e10 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1015,6 +1015,9 @@ impl Session { } }; (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) + } else if let Some(Value::String(output)) = o.get("output") { + // Fallback for other MCP notification types + (output.to_owned(), None, None) } else if let Some(result) = format_task_execution_notification(data) { result } else { From 41fc7ca86e8c52c0ef86331d298c1411734db4e4 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 19:45:20 +1000 Subject: [PATCH 23/43] removed initial worker size --- .../sub_recipe_execute_task_tool.rs | 3 --- .../goose/src/agents/sub_recipe_execution_tool/types.rs | 8 -------- 2 files changed, 11 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 80a537393e5b..10c35e7b2e32 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -103,9 +103,6 @@ Pre-created Task Based: "properties": { "max_workers": { "type": "number" - }, - "initial_workers": { - "type": "number" } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index e558cee1f08b..ab0b496db469 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -128,30 +128,22 @@ impl SharedState { pub struct Config { #[serde(default = "default_max_workers")] pub max_workers: usize, - #[serde(default = "default_initial_workers")] - pub initial_workers: usize, } impl Default for Config { fn default() -> Self { Self { max_workers: default_max_workers(), - initial_workers: default_initial_workers(), } } } const DEFAULT_MAX_WORKERS: usize = 10; -const DEFAULT_INITIAL_WORKERS: usize = 2; fn default_max_workers() -> usize { DEFAULT_MAX_WORKERS } -fn default_initial_workers() -> usize { - DEFAULT_INITIAL_WORKERS -} - #[derive(Debug, Serialize)] pub struct ExecutionStats { pub total_tasks: usize, From 353c52a3f0848ff90bee083c2794eb57ed9f6078 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 19:55:43 +1000 Subject: [PATCH 24/43] removed worker config --- .claude/settings.local.json | 9 +++++++++ .../sub_recipe_execution_tool/executor/mod.rs | 7 +++---- .../sub_recipe_execution_tool/lib/mod.rs | 11 ++--------- .../sub_recipe_execute_task_tool.rs | 8 -------- .../agents/sub_recipe_execution_tool/types.rs | 19 ------------------- 5 files changed, 14 insertions(+), 40 deletions(-) create mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 000000000000..8a954032a156 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(grep:*)", + "Bash(awk:*)" + ], + "deny": [] + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index c6183b7c137b..bb73b73e9c54 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -5,7 +5,7 @@ use tokio::sync::mpsc; use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::lib::{ - Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, + ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ DisplayMode, TaskExecutionTracker, @@ -17,6 +17,7 @@ use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; mod tests; const EXECUTION_STATUS_COMPLETED: &str = "completed"; +const DEFAULT_MAX_WORKERS: usize = 10; pub async fn execute_single_task( task: &Task, @@ -41,7 +42,6 @@ pub async fn execute_single_task( pub async fn execute_tasks_in_parallel( tasks: Vec, - config: Config, notifier: mpsc::Sender, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( @@ -67,8 +67,7 @@ pub async fn execute_tasks_in_parallel( let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); - // Simple static worker allocation - no dynamic scaling needed - let worker_count = std::cmp::min(task_count, config.max_workers); + let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); let mut worker_handles = Vec::new(); for i in 0..worker_count { let handle = spawn_worker(shared_state.clone(), i); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 0515bf37bb44..b6a0361322cb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -2,7 +2,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, + ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -22,13 +22,6 @@ pub async fn execute_tasks( serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) .map_err(|e| format!("Failed to parse tasks: {}", e))?; - let config: Config = if let Some(config_value) = input.get("config") { - serde_json::from_value(config_value.clone()) - .map_err(|e| format!("Failed to parse config: {}", e))? - } else { - Config::default() - }; - let task_count = tasks.len(); match execution_mode { ExecutionMode::Sequential => { @@ -41,7 +34,7 @@ pub async fn execute_tasks( } ExecutionMode::Parallel => { let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, config, notifier).await; + execute_tasks_in_parallel(tasks, notifier).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 10c35e7b2e32..7a37fd51c4e3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -97,14 +97,6 @@ Pre-created Task Based: "required": ["id", "payload"] }, "description": "The tasks to run in parallel" - }, - "config": { - "type": "object", - "properties": { - "max_workers": { - "type": "number" - } - } } }, "required": ["tasks"] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index ab0b496db469..74cd8b6e6436 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -124,25 +124,6 @@ impl SharedState { } } -#[derive(Debug, Clone, Deserialize)] -pub struct Config { - #[serde(default = "default_max_workers")] - pub max_workers: usize, -} - -impl Default for Config { - fn default() -> Self { - Self { - max_workers: default_max_workers(), - } - } -} - -const DEFAULT_MAX_WORKERS: usize = 10; - -fn default_max_workers() -> usize { - DEFAULT_MAX_WORKERS -} #[derive(Debug, Serialize)] pub struct ExecutionStats { From 17a52138b5a9a4db042674032c4a5422a2ce5e52 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 20:27:35 +1000 Subject: [PATCH 25/43] fixed fmt --- .claude/settings.local.json | 9 --------- .gitignore | 3 +++ .../src/agents/sub_recipe_execution_tool/lib/mod.rs | 6 ++---- .../goose/src/agents/sub_recipe_execution_tool/types.rs | 1 - 4 files changed, 5 insertions(+), 14 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 8a954032a156..000000000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(grep:*)", - "Bash(awk:*)" - ], - "deny": [] - } -} \ No newline at end of file diff --git a/.gitignore b/.gitignore index caab83d726c7..41f629f09366 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ ui/desktop/src/bin/goose_llm.dll # Hermit .hermit/ +# Claude +.claude + debug_*.txt # Docs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index b6a0361322cb..a4aedec568c7 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -2,8 +2,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; pub use crate::agents::sub_recipe_execution_tool::types::{ - ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, - TaskStatus, + ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; #[cfg(test)] @@ -33,8 +32,7 @@ pub async fn execute_tasks( } } ExecutionMode::Parallel => { - let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, notifier).await; + let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 74cd8b6e6436..ea31746032d7 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -124,7 +124,6 @@ impl SharedState { } } - #[derive(Debug, Serialize)] pub struct ExecutionStats { pub total_tasks: usize, From 4f52ab1e1a3942c038896a93985bc51fb69861b3 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 11 Jul 2025 16:58:19 +1000 Subject: [PATCH 26/43] renamed types.rs to task_types.rs --- crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs | 2 +- crates/goose/src/agents/sub_recipe_execution_tool/mod.rs | 2 +- .../agents/sub_recipe_execution_tool/notification_events.rs | 2 +- .../sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs | 4 ++-- .../sub_recipe_execution_tool/task_execution_tracker.rs | 4 +++- .../sub_recipe_execution_tool/{types.rs => task_types.rs} | 0 crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs | 2 +- .../goose/src/agents/sub_recipe_execution_tool/utils/mod.rs | 2 +- .../goose/src/agents/sub_recipe_execution_tool/utils/tests.rs | 2 +- crates/goose/src/agents/sub_recipe_execution_tool/workers.rs | 2 +- 10 files changed, 12 insertions(+), 10 deletions(-) rename crates/goose/src/agents/sub_recipe_execution_tool/{types.rs => task_types.rs} (100%) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index a4aedec568c7..746e4ecb151c 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -1,7 +1,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; -pub use crate::agents::sub_recipe_execution_tool::types::{ +pub use crate::agents::sub_recipe_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index 267caab8f115..b6363ba20d14 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -3,7 +3,7 @@ pub mod lib; pub mod notification_events; pub mod sub_recipe_execute_task_tool; mod task_execution_tracker; +mod task_types; mod tasks; -mod types; pub mod utils; mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs index 97a4576b2d99..2a6134ea1a55 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs @@ -1,4 +1,4 @@ -use crate::agents::sub_recipe_execution_tool::types::TaskStatus; +use crate::agents::sub_recipe_execution_tool::task_types::TaskStatus; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 7a37fd51c4e3..62054cca18e6 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -2,8 +2,8 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::types::ExecutionMode, - tool_execution::ToolCallResult, + sub_recipe_execution_tool::lib::execute_tasks, + sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 639e4580b843..b456fd77424f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -9,7 +9,9 @@ use crate::agents::sub_recipe_execution_tool::notification_events::{ FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, TaskInfo as EventTaskInfo, }; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{ + Task, TaskInfo, TaskResult, TaskStatus, +}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs similarity index 100% rename from crates/goose/src/agents/sub_recipe_execution_tool/types.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 60d022816680..fb41b0632d18 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -7,7 +7,7 @@ use tokio::process::Command; use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index b86a69a8fcfe..1ead865e6571 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{TaskInfo, TaskStatus}; pub fn get_task_name(task_info: &TaskInfo) -> &str { task_info diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index 618026a7f003..f799b699aaca 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,4 +1,4 @@ -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskInfo, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::json; use std::collections::HashMap; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 35e9f6d22219..fefbf0eb82b2 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,5 +1,5 @@ +use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task}; use std::sync::Arc; async fn receive_task(state: &SharedState) -> Option { From 8f4f55f5d7dff2a633beacce7bbfda782686bc0f Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Mon, 14 Jul 2025 13:26:31 -0700 Subject: [PATCH 27/43] dynamic tool --- crates/goose-cli/src/session/builder.rs | 2 ++ crates/goose/src/agents/agent.rs | 8 ++++++-- crates/goose/src/agents/recipe_tools/mod.rs | 1 + crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs | 4 ++-- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index b5d3da1d7f45..0377a36e905e 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -187,6 +187,8 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); + + // Sub-recipes if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 38a488a9781a..b9592bff808a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,6 +9,9 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; +use crate::agents::recipe_tools::dynamic_task_tools::{ + create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, +}; use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; @@ -53,7 +56,6 @@ use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; use super::subagent_manager::SubAgentManager; -use super::subagent_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -295,6 +297,8 @@ impl Agent { .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { + create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( @@ -559,7 +563,7 @@ impl Agent { // Add subagent tool (only if ALPHA_FEATURES is enabled) let config = Config::global(); if config.get_param::("ALPHA_FEATURES").unwrap_or(false) { - prefixed_tools.push(subagent_tools::run_task_subagent_tool()); + prefixed_tools.push(create_dynamic_task_tool()); } // Add resource tools if supported diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 90603c88488e..6e6f28a80310 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1,2 +1,3 @@ +pub mod dynamic_task_tools; pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index fd0edaeaecd1..bc0347dfd509 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -11,8 +11,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -const EXECUTION_MODE_PARALLEL: &str = "parallel"; -const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; +pub const EXECUTION_MODE_PARALLEL: &str = "parallel"; +pub const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); From 6c6113d2bd02d7b2baa4beed07b9646046a6e8f9 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Mon, 14 Jul 2025 13:27:11 -0700 Subject: [PATCH 28/43] add dynamic task tool --- .../agents/recipe_tools/dynamic_task_tools.rs | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs new file mode 100644 index 000000000000..77a66bee3fac --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -0,0 +1,138 @@ +// ======================================= +// Module: Dynamic Task Tools +// Handles creation of tasks dynamically without sub-recipes +// ======================================= +use crate::agents::recipe_tools::sub_recipe_tools::{ + EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, +}; +use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::tool_execution::ToolCallResult; +use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; +use serde_json::{json, Value}; + +pub const DYNAMIC_TASK_TOOL_NAME_PREFIX: &str = "dynamic_task__create_task"; + +pub fn create_dynamic_task_tool() -> Tool { + Tool::new( + format!("{}", DYNAMIC_TASK_TOOL_NAME_PREFIX), + format!( + "Creates a dynamic task object(s) based on textual instructions. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same text instruction but with different parameter values. \ + This is useful when you need to execute the same instruction multiple times with varying inputs. \ + After creating the task list, pass it to the task executor to run all tasks." + ), + json!({ + "type": "object", + "properties": { + "task_parameters": { + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": { + "text_instruction": { + "type": "string", + "description": "The text instruction to execute" + }, + "timeout_seconds": { + "type": "integer", + "description": "Optional timeout for the task in seconds (default: 300)", + "minimum": 1 + } + }, + "required": ["text_instruction"] + } + } + } + }), + Some(ToolAnnotations { + title: Some(format!("Dynamic Task Creation")), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + +fn extract_task_parameters(params: &Value) -> Vec { + params + .get("task_parameters") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default() +} + +fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec { + task_params + .iter() + .map(|task_param| { + let text_instruction = task_param + .get("text_instruction") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let timeout_seconds = task_param + .get("timeout_seconds") + .and_then(|v| v.as_u64()) + .unwrap_or(300); + + let payload = json!({ + "text_instruction": text_instruction + }); + + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "text_instruction".to_string(), + timeout_in_seconds: Some(timeout_seconds), + payload, + } + }) + .collect() +} + +fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { + json!({ + "tasks": tasks, + "execution_mode": execution_mode + }) +} + +pub async fn create_dynamic_task(params: Value) -> ToolCallResult { + let task_params_array = extract_task_parameters(¶ms); + + if task_params_array.is_empty() { + return ToolCallResult::from(Err(ToolError::ExecutionError( + "No task parameters provided".to_string(), + ))); + } + + let tasks = create_text_instruction_tasks_from_params(&task_params_array); + + // Use parallel execution if there are multiple tasks, sequential for single task + let execution_mode = if tasks.len() > 1 { + EXECUTION_MODE_PARALLEL + } else { + EXECUTION_MODE_SEQUENTIAL + }; + + let task_execution_payload = create_task_execution_payload(tasks, execution_mode); + + let tasks_json = match serde_json::to_string(&task_execution_payload) { + Ok(json) => json, + Err(e) => { + return ToolCallResult::from(Err(ToolError::ExecutionError(format!( + "Failed to serialize task list: {}", + e + )))) + } + }; + ToolCallResult::from(Ok(vec![Content::text(tasks_json)])) +} From 2c0d73f05406211af9b1f81e0b7bf27f17c91a05 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Mon, 14 Jul 2025 16:11:34 -0700 Subject: [PATCH 29/43] draft --- crates/goose/src/agents/agent.rs | 37 ++--- crates/goose/src/agents/mod.rs | 2 - .../sub_recipe_execution_tool/executor/mod.rs | 17 +- .../sub_recipe_execution_tool/lib/mod.rs | 13 +- .../sub_recipe_execute_task_tool.rs | 42 +++-- .../agents/sub_recipe_execution_tool/tasks.rs | 116 +++++++++++--- .../sub_recipe_execution_tool/workers.rs | 14 +- crates/goose/src/agents/subagent.rs | 120 +++++++++------ crates/goose/src/agents/subagent_handler.rs | 145 ++++++++++-------- 9 files changed, 331 insertions(+), 175 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index b9592bff808a..657877077ea4 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -50,12 +50,9 @@ use mcp_core::{ prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult, }; -use crate::agents::subagent_tools::SUBAGENT_RUN_TASK_TOOL_NAME; - use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; -use super::subagent_manager::SubAgentManager; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -63,7 +60,7 @@ const DEFAULT_MAX_TURNS: u32 = 1000; /// The main goose Agent pub struct Agent { pub(super) provider: Mutex>>, - pub(super) extension_manager: RwLock, + pub(super) extension_manager: Arc>, pub(super) sub_recipe_manager: Mutex, pub(super) final_output_tool: Mutex>, pub(super) frontend_tools: Mutex>, @@ -76,7 +73,7 @@ pub struct Agent { pub(super) tool_monitor: Mutex>, pub(super) router_tool_selector: Mutex>>>, pub(super) scheduler_service: Mutex>>, - pub(super) subagent_manager: Mutex>, + pub(super) mcp_tx: Mutex>, pub(super) mcp_notification_rx: Arc>>, } @@ -137,7 +134,7 @@ impl Agent { Self { provider: Mutex::new(None), - extension_manager: RwLock::new(ExtensionManager::new()), + extension_manager: Arc::new(RwLock::new(ExtensionManager::new())), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), final_output_tool: Mutex::new(None), frontend_tools: Mutex::new(HashMap::new()), @@ -151,7 +148,7 @@ impl Agent { router_tool_selector: Mutex::new(None), scheduler_service: Mutex::new(None), // Initialize with MCP notification support - subagent_manager: Mutex::new(Some(SubAgentManager::new(mcp_tx))), + mcp_tx: Mutex::new(mcp_tx), mcp_notification_rx: Arc::new(Mutex::new(mcp_rx)), } } @@ -296,7 +293,17 @@ impl Agent { .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { - sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + // Get the provider and extension manager for text instruction tasks + let provider = self.provider().await.ok(); + let extension_manager = Some(Arc::clone(&self.extension_manager)); + + sub_recipe_execute_task_tool::run_tasks( + tool_call.arguments.clone(), + self.mcp_tx.lock().await.clone(), + provider, + extension_manager, + ) + .await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { @@ -314,11 +321,6 @@ impl Agent { ) } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { ToolCallResult::from(extension_manager.search_available_extensions().await) - } else if tool_call.name == SUBAGENT_RUN_TASK_TOOL_NAME { - ToolCallResult::from( - self.handle_run_subagent_task(tool_call.arguments.clone()) - .await, - ) } else if self.is_frontend_tool(&tool_call.name).await { // For frontend tools, return an error indicating we need frontend execution ToolCallResult::from(Err(ToolError::ExecutionError( @@ -1042,15 +1044,6 @@ impl Agent { let mut current_provider = self.provider.lock().await; *current_provider = Some(provider.clone()); - // Initialize subagent manager with MCP notification support - // Need to recreate the MCP channel since we're replacing the manager - let (mcp_tx, mcp_rx) = mpsc::channel(100); - { - let mut rx_guard = self.mcp_notification_rx.lock().await; - *rx_guard = mcp_rx; - } - *self.subagent_manager.lock().await = Some(SubAgentManager::new(mcp_tx)); - self.update_router_tool_selector(Some(provider), None) .await?; Ok(()) diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 353e57acde12..60a0b4ee2eae 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -15,7 +15,6 @@ pub mod sub_recipe_execution_tool; pub mod sub_recipe_manager; pub mod subagent; pub mod subagent_handler; -pub mod subagent_manager; pub mod subagent_tools; pub mod subagent_types; mod tool_execution; @@ -28,6 +27,5 @@ pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; pub use subagent::{SubAgent, SubAgentConfig, SubAgentProgress, SubAgentStatus}; -pub use subagent_manager::SubAgentManager; pub use subagent_types::SpawnSubAgentArgs; pub use types::{FrontendTool, SessionConfig}; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index bb73b73e9c54..c5035e671fe4 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; +use crate::agents::extension_manager::ExtensionManager; use crate::agents::sub_recipe_execution_tool::lib::{ ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -12,6 +13,7 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ }; use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; +use crate::providers::base::Provider; #[cfg(test)] mod tests; @@ -22,6 +24,9 @@ const DEFAULT_MAX_WORKERS: usize = 10; pub async fn execute_single_task( task: &Task, notifier: mpsc::Sender, + mcp_tx: mpsc::Sender, + provider: Option>, + extension_manager: Option>>, ) -> ExecutionResponse { let start_time = Instant::now(); let task_execution_tracker = Arc::new(TaskExecutionTracker::new( @@ -29,7 +34,14 @@ pub async fn execute_single_task( DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task(task, task_execution_tracker).await; + let result = process_task( + task, + task_execution_tracker, + mcp_tx, + provider, + extension_manager, + ) + .await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -43,6 +55,7 @@ pub async fn execute_single_task( pub async fn execute_tasks_in_parallel( tasks: Vec, notifier: mpsc::Sender, + mcp_tx: mpsc::Sender, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( tasks.clone(), @@ -70,7 +83,7 @@ pub async fn execute_tasks_in_parallel( let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); let mut worker_handles = Vec::new(); for i in 0..worker_count { - let handle = spawn_worker(shared_state.clone(), i); + let handle = spawn_worker(shared_state.clone(), i, mcp_tx.clone()); worker_handles.push(handle); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 746e4ecb151c..63678e709050 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -8,14 +8,21 @@ pub use crate::agents::sub_recipe_execution_tool::task_types::{ #[cfg(test)] mod tests; +use crate::agents::extension_manager::ExtensionManager; +use crate::providers::base::Provider; use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; +use std::sync::Arc; use tokio::sync::mpsc; +use tokio::sync::RwLock; pub async fn execute_tasks( input: Value, execution_mode: ExecutionMode, notifier: mpsc::Sender, + mcp_tx: mpsc::Sender, + provider: Option>, + extension_manager: Option>>, ) -> Result { let tasks: Vec = serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) @@ -25,14 +32,16 @@ pub async fn execute_tasks( match execution_mode { ExecutionMode::Sequential => { if task_count == 1 { - let response = execute_single_task(&tasks[0], notifier).await; + let response = + execute_single_task(&tasks[0], notifier, mcp_tx, provider, extension_manager) + .await; handle_response(response) } else { Err("Sequential execution mode requires exactly one task".to_string()) } } ExecutionMode::Parallel => { - let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier).await; + let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier, mcp_tx).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 62054cca18e6..9318a0c54b39 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -1,13 +1,19 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; +use crate::agents::extension_manager::ExtensionManager; use crate::agents::{ sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, }; +use crate::providers::base::Provider; use mcp_core::protocol::JsonRpcMessage; +use std::sync::Arc; use tokio::sync::mpsc; use tokio_stream; +use std::future::Future; +use std::pin::Pin; +use tokio::sync::RwLock; pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; pub fn create_sub_recipe_execute_task_tool() -> Tool { @@ -31,7 +37,7 @@ IMPLEMENTATION: EXAMPLES: User Intent Based: - User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) -- User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) +- User: 'get weather and joke in parallel' → Parallel (1 tool call with task array) - User: 'run these simultaneously' → Parallel (1 tool call with task array) - User: 'do task A then task B' → Sequential (2 separate tool calls) @@ -111,17 +117,30 @@ Pre-created Task Based: ) } -pub async fn run_tasks(execute_data: Value) -> ToolCallResult { +pub async fn run_tasks( + execute_data: Value, + mcp_tx: mpsc::Sender, + provider: Option>, + extension_manager: Option>>, +) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); - let result_future = async move { - let execute_data_clone = execute_data.clone(); - let execution_mode = execute_data_clone - .get("execution_mode") - .and_then(|v| serde_json::from_value::(v.clone()).ok()) - .unwrap_or_default(); + let execution_mode = execute_data + .get("execution_mode") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or_default(); - match execute_tasks(execute_data, execution_mode, notification_tx).await { + let result_future = async move { + match execute_tasks( + execute_data, + execution_mode, + notification_tx, + mcp_tx, + provider, + extension_manager, + ) + .await + { Ok(result) => { let output = serde_json::to_string(&result).unwrap(); Ok(vec![Content::text(output)]) @@ -130,11 +149,8 @@ pub async fn run_tasks(execute_data: Value) -> ToolCallResult { } }; - // Convert receiver to stream - let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx); - ToolCallResult { result: Box::new(Box::pin(result_future)), - notification_stream: Some(Box::new(notification_stream)), + notification_stream: Some(Box::new(tokio_stream::wrappers::ReceiverStream::new(notification_rx))), } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index fb41b0632d18..ec9da009b5c4 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -4,16 +4,25 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; +use tokio::sync::mpsc; use tokio::time::timeout; +use crate::agents::extension_manager::ExtensionManager; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; +use crate::agents::subagent_handler::run_complete_subagent_task; +use crate::providers::base::Provider; +use mcp_core::protocol::JsonRpcMessage; +use tokio::sync::RwLock; const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; pub async fn process_task( task: &Task, task_execution_tracker: Arc, + mcp_tx: mpsc::Sender, + provider: Option>, + extension_manager: Option>>, ) -> TaskResult { let timeout_in_seconds = task .timeout_in_seconds @@ -24,7 +33,13 @@ pub async fn process_task( let task_execution_tracker_clone = task_execution_tracker.clone(); match timeout( timeout_duration, - get_task_result(task_clone, task_execution_tracker), + get_task_result( + task_clone, + task_execution_tracker, + mcp_tx, + provider, + extension_manager, + ), ) .await { @@ -61,20 +76,86 @@ pub async fn process_task( async fn get_task_result( task: Task, task_execution_tracker: Arc, + mcp_tx: mpsc::Sender, + provider: Option>, + extension_manager: Option>>, ) -> Result { - let (command, output_identifier) = build_command(&task)?; - let (stdout_output, stderr_output, success) = run_command( - command, - &output_identifier, - &task.id, - task_execution_tracker, - ) - .await?; - - if success { - process_output(stdout_output) + if task.task_type == "text_instruction" { + // Handle text_instruction tasks using subagent system + handle_text_instruction_task( + task, + task_execution_tracker, + mcp_tx, + provider, + extension_manager, + ) + .await } else { - Err(format!("Command failed:\n{}", stderr_output)) + // Handle sub_recipe tasks using command execution + let (command, output_identifier) = build_command(&task)?; + let (stdout_output, stderr_output, success) = run_command( + command, + &output_identifier, + &task.id, + task_execution_tracker, + ) + .await?; + + if success { + process_output(stdout_output) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } + } +} + +async fn handle_text_instruction_task( + task: Task, + task_execution_tracker: Arc, + mcp_tx: mpsc::Sender, + provider: Option>, + extension_manager: Option>>, +) -> Result { + let text_instruction = task + .get_text_instruction() + .ok_or_else(|| format!("Task {}: Missing text_instruction", task.id))?; + + // Check if we have the required dependencies for subagent execution + let (provider, extension_manager) = match (provider, extension_manager) { + (Some(p), Some(em)) => (p, em), + _ => { + return Err( + "Text instruction tasks require provider and extension_manager".to_string(), + ); + } + }; + + // Create arguments for the subagent task + let arguments = serde_json::json!({ + "task": text_instruction, + "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", + "max_turns": 5, + "timeout_seconds": task.timeout_in_seconds.unwrap_or(300) + }); + + // Execute the text instruction using the subagent system + match run_complete_subagent_task(arguments, mcp_tx, provider, Some(extension_manager)).await { + Ok(contents) => { + // Extract the text content from the result + let result_text = contents + .into_iter() + .filter_map(|content| match content { + mcp_core::Content::Text(text) => Some(text.text), + _ => None, + }) + .collect::>() + .join("\n"); + + Ok(serde_json::json!({ + "result": result_text + })) + } + Err(e) => Err(format!("Subagent execution failed: {}", e)), } } @@ -105,12 +186,9 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { } cmd } else { - let text = task - .get_text_instruction() - .ok_or_else(|| task_error("text_instruction"))?; - let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--text").arg(text); - cmd + // This branch should not be reached for text_instruction tasks anymore + // as they are handled in handle_text_instruction_task + return Err("Text instruction tasks are handled separately".to_string()); }; command.stdout(Stdio::piped()); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index fefbf0eb82b2..d5aff90e39db 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,24 +1,30 @@ use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; use std::sync::Arc; +use mcp_core::protocol::JsonRpcMessage; +use tokio::sync::mpsc; async fn receive_task(state: &SharedState) -> Option { let mut receiver = state.task_receiver.lock().await; receiver.recv().await } -pub fn spawn_worker(state: Arc, worker_id: usize) -> tokio::task::JoinHandle<()> { +pub fn spawn_worker( + state: Arc, + worker_id: usize, + mcp_tx: mpsc::Sender, +) -> tokio::task::JoinHandle<()> { state.increment_active_workers(); tokio::spawn(async move { - worker_loop(state, worker_id).await; + worker_loop(state, worker_id, mcp_tx).await; }) } -async fn worker_loop(state: Arc, _worker_id: usize) { +async fn worker_loop(state: Arc, _worker_id: usize, mcp_tx: mpsc::Sender) { while let Some(task) = receive_task(&state).await { state.task_execution_tracker.start_task(&task.id).await; - let result = process_task(&task, state.task_execution_tracker.clone()).await; + let result = process_task(&task, state.task_execution_tracker.clone(), mcp_tx.clone(), None, None).await; if let Err(e) = state.result_sender.send(result).await { tracing::error!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 0a02e2d1db73..ba21ad323a58 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -104,7 +104,7 @@ impl SubAgent { pub async fn new( config: SubAgentConfig, _provider: Arc, - extension_manager: Arc>, + extension_manager: Option>>, mcp_notification_tx: mpsc::Sender, ) -> Result<(Arc, tokio::task::JoinHandle<()>), anyhow::Error> { debug!("Creating new subagent with id: {}", config.id); @@ -113,23 +113,25 @@ impl SubAgent { let mut recipe_extensions = Vec::new(); // Check if extensions from recipe exist in the extension manager - if let Some(recipe) = &config.recipe { - if let Some(extensions) = &recipe.extensions { - for extension in extensions { - let extension_name = extension.name(); - let existing_extensions = extension_manager.list_extensions().await?; - - if !existing_extensions.contains(&extension_name) { - missing_extensions.push(extension_name); - } else { - recipe_extensions.push(extension_name); + if let Some(extension_manager) = &extension_manager { + if let Some(recipe) = &config.recipe { + if let Some(extensions) = &recipe.extensions { + for extension in extensions { + let extension_name = extension.name(); + let existing_extensions = extension_manager.read().await.list_extensions().await?; + + if !existing_extensions.contains(&extension_name) { + missing_extensions.push(extension_name); + } else { + recipe_extensions.push(extension_name); + } } } + } else { + // If no recipe, inherit all extensions from the parent agent + let existing_extensions = extension_manager.read().await.list_extensions().await?; + recipe_extensions = existing_extensions; } - } else { - // If no recipe, inherit all extensions from the parent agent - let existing_extensions = extension_manager.list_extensions().await?; - recipe_extensions = existing_extensions; } let subagent = Arc::new(SubAgent { @@ -243,7 +245,7 @@ impl SubAgent { &self, message: String, provider: Arc, - extension_manager: Arc>, + extension_manager: Option>>, ) -> Result { debug!("Processing message for subagent {}", self.id); self.send_mcp_notification("message_processing", &format!("Processing: {}", message)) @@ -299,24 +301,28 @@ impl SubAgent { recipe_extensions.len() ); - for extension_name in recipe_extensions.iter() { - match extension_manager - .get_prefixed_tools(Some(extension_name.clone())) - .await - { - Ok(mut ext_tools) => { - debug!( - "Added {} tools from extension {}", - ext_tools.len(), - extension_name - ); - recipe_tools.append(&mut ext_tools); - } - Err(e) => { - debug!( - "Failed to get tools for extension {}: {}", - extension_name, e - ); + if let Some(extension_manager) = &extension_manager { + for extension_name in recipe_extensions.iter() { + match extension_manager + .read() + .await + .get_prefixed_tools(Some(extension_name.clone())) + .await + { + Ok(mut ext_tools) => { + debug!( + "Added {} tools from extension {}", + ext_tools.len(), + extension_name + ); + recipe_tools.append(&mut ext_tools); + } + Err(e) => { + debug!( + "Failed to get tools for extension {}: {}", + extension_name, e + ); + } } } } @@ -330,7 +336,9 @@ impl SubAgent { let mut filtered_tools = Self::filter_subagent_tools(recipe_tools); // Add platform tools (except subagent tools) - Self::add_platform_tools(&mut filtered_tools, &extension_manager).await; + if let Some(extension_manager) = &extension_manager { + Self::add_platform_tools(&mut filtered_tools, &extension_manager.read().await).await; + } debug!( "Subagent {} has {} tools after filtering and adding platform tools", @@ -344,7 +352,11 @@ impl SubAgent { "Subagent {} operating in inheritance mode, using all parent tools", self.id ); - let parent_tools = extension_manager.get_prefixed_tools(None).await?; + let parent_tools = if let Some(extension_manager) = &extension_manager { + extension_manager.read().await.get_prefixed_tools(None).await? + } else { + Vec::new() + }; debug!( "Subagent {} has {} parent tools before filtering", self.id, @@ -353,7 +365,9 @@ impl SubAgent { let mut filtered_tools = Self::filter_subagent_tools(parent_tools); // Add platform tools (except subagent tools) - Self::add_platform_tools(&mut filtered_tools, &extension_manager).await; + if let Some(extension_manager) = &extension_manager { + Self::add_platform_tools(&mut filtered_tools, &extension_manager.read().await).await; + } debug!( "Subagent {} has {} tools after filtering and adding platform tools", @@ -428,18 +442,28 @@ impl SubAgent { // Handle platform tools or dispatch to extension manager let tool_result = if self.is_platform_tool(&tool_call.name) { - self.handle_platform_tool_call( - tool_call.clone(), - &extension_manager, - ) - .await - } else { - match extension_manager - .dispatch_tool_call(tool_call.clone()) + if let Some(extension_manager) = &extension_manager { + self.handle_platform_tool_call( + tool_call.clone(), + &extension_manager.read().await, + ) .await - { - Ok(result) => result.result.await, - Err(e) => Err(ToolError::ExecutionError(e.to_string())), + } else { + Err(ToolError::ExecutionError("No extension manager available".to_string())) + } + } else { + if let Some(extension_manager) = &extension_manager { + match extension_manager + .read() + .await + .dispatch_tool_call(tool_call.clone()) + .await + { + Ok(result) => result.result.await, + Err(e) => Err(ToolError::ExecutionError(e.to_string())), + } + } else { + Err(ToolError::ExecutionError("No extension manager available".to_string())) } }; diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index f281f7488e4d..019c09f0ef44 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -1,79 +1,98 @@ use anyhow::Result; use mcp_core::{Content, ToolError}; +use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::RwLock; -use crate::agents::subagent_types::SpawnSubAgentArgs; -use crate::agents::Agent; +use crate::agents::extension_manager::ExtensionManager; +use crate::providers::base::Provider; +use crate::agents::subagent::{SubAgent, SubAgentConfig}; -impl Agent { - /// Handle running a complete subagent task (replaces the individual spawn/send/check tools) - pub async fn handle_run_subagent_task( - &self, - arguments: Value, - ) -> Result, ToolError> { - let subagent_manager = self.subagent_manager.lock().await; - let manager = subagent_manager.as_ref().ok_or_else(|| { - ToolError::ExecutionError("Subagent manager not initialized".to_string()) - })?; +/// Standalone function to run a complete subagent task +pub async fn run_complete_subagent_task( + arguments: Value, + mcp_tx: mpsc::Sender, + provider: Arc, + extension_manager: Option>>, +) -> Result, ToolError> { + // Parse arguments - using "task" as the main message parameter + let message = arguments + .get("task") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::ExecutionError("Missing task parameter".to_string()))? + .to_string(); - // Parse arguments - using "task" as the main message parameter - let message = arguments - .get("task") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::ExecutionError("Missing task parameter".to_string()))? - .to_string(); + // Get instructions from arguments + let instructions = arguments + .get("instructions") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::ExecutionError("Missing instructions parameter".to_string()))? + .to_string(); - // Either recipe_name or instructions must be provided - let recipe_name = arguments - .get("recipe_name") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - let instructions = arguments - .get("instructions") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); + // Set max_turns with default of 10 + let max_turns = arguments + .get("max_turns") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize; - let mut args = if let Some(recipe_name) = recipe_name { - SpawnSubAgentArgs::new_with_recipe(recipe_name, message.clone()) - } else if let Some(instructions) = instructions { - SpawnSubAgentArgs::new_with_instructions(instructions, message.clone()) - } else { - return Err(ToolError::ExecutionError( - "Either recipe_name or instructions parameter must be provided".to_string(), - )); - }; - - // Set max_turns with default of 10 - let max_turns = arguments - .get("max_turns") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - args = args.with_max_turns(max_turns); + let timeout = arguments.get("timeout_seconds").and_then(|v| v.as_u64()); - if let Some(timeout) = arguments.get("timeout_seconds").and_then(|v| v.as_u64()) { - args = args.with_timeout(timeout); - } + // Create subagent config with instructions + let mut config = SubAgentConfig::new_with_instructions(instructions); + config = config.with_max_turns(max_turns); + if let Some(timeout) = timeout { + config = config.with_timeout(timeout); + } - // Get the provider from the parent agent - let provider = self - .provider() - .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to get provider: {}", e)))?; + // Create the subagent with the parent agent's provider + let extension_manager_clone = extension_manager.clone(); + let (subagent, handle) = SubAgent::new( + config, + Arc::clone(&provider), + extension_manager, + mcp_tx, + ) + .await + .map_err(|e| ToolError::ExecutionError(format!("Failed to create subagent: {}", e)))?; - // Get the extension manager from the parent agent - let extension_manager = Arc::new(self.extension_manager.read().await); + // Run the complete conversation + let mut conversation_result = String::new(); + let turn_count = 0; - // Run the complete subagent task - match manager - .run_complete_subagent_task(args, provider, extension_manager) - .await - { - Ok(result) => Ok(vec![Content::text(result)]), - Err(e) => Err(ToolError::ExecutionError(format!( - "Failed to run subagent task: {}", - e - ))), + // Execute the subagent task + match subagent + .reply_subagent( + message, + Arc::clone(&provider), + extension_manager_clone, + ) + .await + { + Ok(response) => { + let response_text = response.as_concat_text(); + conversation_result.push_str(&format!( + "\n--- Turn {} ---\n{}", + turn_count + 1, + response_text + )); + conversation_result.push_str(&format!( + "\n[Task completed after {} turns]", + turn_count + 1 + )); + } + Err(e) => { + conversation_result + .push_str(&format!("\n[Error after {} turns: {}]", turn_count, e)); } } + + // Clean up the subagent handle + if let Err(e) = handle.await { + tracing::debug!("Subagent handle cleanup error: {}", e); + } + + // Return the complete conversation result + Ok(vec![Content::text(format!("Subagent task completed:\n{}", conversation_result))]) } From 0e016303a7a00e120b7942f8ec413328e18dde9f Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 15 Jul 2025 09:47:14 -0700 Subject: [PATCH 30/43] feat: dynamic tasks (#3414) --- crates/goose-cli/src/session/builder.rs | 2 + crates/goose/src/agents/agent.rs | 8 +- .../agents/recipe_tools/dynamic_task_tools.rs | 138 ++++++++++++++++++ crates/goose/src/agents/recipe_tools/mod.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 4 +- 5 files changed, 149 insertions(+), 4 deletions(-) create mode 100644 crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index b5d3da1d7f45..0377a36e905e 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -187,6 +187,8 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); + + // Sub-recipes if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 38a488a9781a..b9592bff808a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,6 +9,9 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; +use crate::agents::recipe_tools::dynamic_task_tools::{ + create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, +}; use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; @@ -53,7 +56,6 @@ use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; use super::subagent_manager::SubAgentManager; -use super::subagent_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -295,6 +297,8 @@ impl Agent { .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { + create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( @@ -559,7 +563,7 @@ impl Agent { // Add subagent tool (only if ALPHA_FEATURES is enabled) let config = Config::global(); if config.get_param::("ALPHA_FEATURES").unwrap_or(false) { - prefixed_tools.push(subagent_tools::run_task_subagent_tool()); + prefixed_tools.push(create_dynamic_task_tool()); } // Add resource tools if supported diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs new file mode 100644 index 000000000000..77a66bee3fac --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -0,0 +1,138 @@ +// ======================================= +// Module: Dynamic Task Tools +// Handles creation of tasks dynamically without sub-recipes +// ======================================= +use crate::agents::recipe_tools::sub_recipe_tools::{ + EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, +}; +use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::tool_execution::ToolCallResult; +use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; +use serde_json::{json, Value}; + +pub const DYNAMIC_TASK_TOOL_NAME_PREFIX: &str = "dynamic_task__create_task"; + +pub fn create_dynamic_task_tool() -> Tool { + Tool::new( + format!("{}", DYNAMIC_TASK_TOOL_NAME_PREFIX), + format!( + "Creates a dynamic task object(s) based on textual instructions. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same text instruction but with different parameter values. \ + This is useful when you need to execute the same instruction multiple times with varying inputs. \ + After creating the task list, pass it to the task executor to run all tasks." + ), + json!({ + "type": "object", + "properties": { + "task_parameters": { + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": { + "text_instruction": { + "type": "string", + "description": "The text instruction to execute" + }, + "timeout_seconds": { + "type": "integer", + "description": "Optional timeout for the task in seconds (default: 300)", + "minimum": 1 + } + }, + "required": ["text_instruction"] + } + } + } + }), + Some(ToolAnnotations { + title: Some(format!("Dynamic Task Creation")), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + +fn extract_task_parameters(params: &Value) -> Vec { + params + .get("task_parameters") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default() +} + +fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec { + task_params + .iter() + .map(|task_param| { + let text_instruction = task_param + .get("text_instruction") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let timeout_seconds = task_param + .get("timeout_seconds") + .and_then(|v| v.as_u64()) + .unwrap_or(300); + + let payload = json!({ + "text_instruction": text_instruction + }); + + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "text_instruction".to_string(), + timeout_in_seconds: Some(timeout_seconds), + payload, + } + }) + .collect() +} + +fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { + json!({ + "tasks": tasks, + "execution_mode": execution_mode + }) +} + +pub async fn create_dynamic_task(params: Value) -> ToolCallResult { + let task_params_array = extract_task_parameters(¶ms); + + if task_params_array.is_empty() { + return ToolCallResult::from(Err(ToolError::ExecutionError( + "No task parameters provided".to_string(), + ))); + } + + let tasks = create_text_instruction_tasks_from_params(&task_params_array); + + // Use parallel execution if there are multiple tasks, sequential for single task + let execution_mode = if tasks.len() > 1 { + EXECUTION_MODE_PARALLEL + } else { + EXECUTION_MODE_SEQUENTIAL + }; + + let task_execution_payload = create_task_execution_payload(tasks, execution_mode); + + let tasks_json = match serde_json::to_string(&task_execution_payload) { + Ok(json) => json, + Err(e) => { + return ToolCallResult::from(Err(ToolError::ExecutionError(format!( + "Failed to serialize task list: {}", + e + )))) + } + }; + ToolCallResult::from(Ok(vec![Content::text(tasks_json)])) +} diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 90603c88488e..6e6f28a80310 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1,2 +1,3 @@ +pub mod dynamic_task_tools; pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index fd0edaeaecd1..bc0347dfd509 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -11,8 +11,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -const EXECUTION_MODE_PARALLEL: &str = "parallel"; -const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; +pub const EXECUTION_MODE_PARALLEL: &str = "parallel"; +pub const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); From 500769940c70513896e508e069c67924238bd43b Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 15 Jul 2025 13:50:50 -0700 Subject: [PATCH 31/43] task config --- crates/goose/src/agents/agent.rs | 12 +- crates/goose/src/agents/mod.rs | 4 +- .../sub_recipe_execution_tool/executor/mod.rs | 22 +- .../sub_recipe_execution_tool/lib/mod.rs | 17 +- .../sub_recipe_execute_task_tool.rs | 15 +- .../agents/sub_recipe_execution_tool/tasks.rs | 90 ++-- .../sub_recipe_execution_tool/workers.rs | 11 +- crates/goose/src/agents/subagent.rs | 391 ++--------------- crates/goose/src/agents/subagent_handler.rs | 55 +-- crates/goose/src/agents/subagent_manager.rs | 404 ------------------ crates/goose/src/agents/task.rs | 74 ++++ 11 files changed, 208 insertions(+), 887 deletions(-) delete mode 100644 crates/goose/src/agents/subagent_manager.rs create mode 100644 crates/goose/src/agents/task.rs diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 657877077ea4..caa0fb043541 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -15,6 +15,7 @@ use crate::agents::recipe_tools::dynamic_task_tools::{ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; +use crate::agents::task::TaskConfig; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::Message; @@ -294,14 +295,17 @@ impl Agent { .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { // Get the provider and extension manager for text instruction tasks + let tools = self.list_tools(None).await; + let extensions = self.list_extensions().await; let provider = self.provider().await.ok(); - let extension_manager = Some(Arc::clone(&self.extension_manager)); + let mcp_tx = self.mcp_tx.lock().await.clone(); + println!("Executing tool call: {:?}", tool_call); + + let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), tools, extensions, mcp_tx); sub_recipe_execute_task_tool::run_tasks( tool_call.arguments.clone(), - self.mcp_tx.lock().await.clone(), - provider, - extension_manager, + task_config, ) .await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 60a0b4ee2eae..66e87be43386 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -17,6 +17,7 @@ pub mod subagent; pub mod subagent_handler; pub mod subagent_tools; pub mod subagent_types; +mod task; mod tool_execution; mod tool_router_index_manager; pub(crate) mod tool_vectordb; @@ -26,6 +27,7 @@ pub use agent::{Agent, AgentEvent}; pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; -pub use subagent::{SubAgent, SubAgentConfig, SubAgentProgress, SubAgentStatus}; +pub use subagent::{SubAgent, SubAgentProgress, SubAgentStatus}; pub use subagent_types::SpawnSubAgentArgs; +pub use task::TaskConfig; pub use types::{FrontendTool, SessionConfig}; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index c5035e671fe4..ba999cebfea0 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -3,8 +3,6 @@ use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; - -use crate::agents::extension_manager::ExtensionManager; use crate::agents::sub_recipe_execution_tool::lib::{ ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -13,7 +11,7 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ }; use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; -use crate::providers::base::Provider; +use crate::agents::task::TaskConfig; #[cfg(test)] mod tests; @@ -24,9 +22,7 @@ const DEFAULT_MAX_WORKERS: usize = 10; pub async fn execute_single_task( task: &Task, notifier: mpsc::Sender, - mcp_tx: mpsc::Sender, - provider: Option>, - extension_manager: Option>>, + task_config: TaskConfig, ) -> ExecutionResponse { let start_time = Instant::now(); let task_execution_tracker = Arc::new(TaskExecutionTracker::new( @@ -36,12 +32,14 @@ pub async fn execute_single_task( )); let result = process_task( task, - task_execution_tracker, - mcp_tx, - provider, - extension_manager, + task_execution_tracker.clone(), + task_config ) .await; + + // Complete the task in the tracker + task_execution_tracker.complete_task(&result.task_id, result.clone()).await; + let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -55,7 +53,7 @@ pub async fn execute_single_task( pub async fn execute_tasks_in_parallel( tasks: Vec, notifier: mpsc::Sender, - mcp_tx: mpsc::Sender, + task_config: TaskConfig, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( tasks.clone(), @@ -83,7 +81,7 @@ pub async fn execute_tasks_in_parallel( let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); let mut worker_handles = Vec::new(); for i in 0..worker_count { - let handle = spawn_worker(shared_state.clone(), i, mcp_tx.clone()); + let handle = spawn_worker(shared_state.clone(), i, task_config.clone()); worker_handles.push(handle); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 63678e709050..fd33476d8215 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -4,25 +4,16 @@ use crate::agents::sub_recipe_execution_tool::executor::{ pub use crate::agents::sub_recipe_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; - -#[cfg(test)] -mod tests; - -use crate::agents::extension_manager::ExtensionManager; -use crate::providers::base::Provider; +use crate::agents::task::TaskConfig; use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; -use std::sync::Arc; use tokio::sync::mpsc; -use tokio::sync::RwLock; pub async fn execute_tasks( input: Value, execution_mode: ExecutionMode, notifier: mpsc::Sender, - mcp_tx: mpsc::Sender, - provider: Option>, - extension_manager: Option>>, + task_config: TaskConfig, ) -> Result { let tasks: Vec = serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) @@ -33,7 +24,7 @@ pub async fn execute_tasks( ExecutionMode::Sequential => { if task_count == 1 { let response = - execute_single_task(&tasks[0], notifier, mcp_tx, provider, extension_manager) + execute_single_task(&tasks[0], notifier, task_config) .await; handle_response(response) } else { @@ -41,7 +32,7 @@ pub async fn execute_tasks( } } ExecutionMode::Parallel => { - let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier, mcp_tx).await; + let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier, task_config).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 9318a0c54b39..e32c49101f95 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -1,19 +1,14 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; -use crate::agents::extension_manager::ExtensionManager; use crate::agents::{ sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, }; -use crate::providers::base::Provider; +use crate::agents::task::TaskConfig; use mcp_core::protocol::JsonRpcMessage; -use std::sync::Arc; use tokio::sync::mpsc; use tokio_stream; -use std::future::Future; -use std::pin::Pin; -use tokio::sync::RwLock; pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; pub fn create_sub_recipe_execute_task_tool() -> Tool { @@ -119,9 +114,7 @@ Pre-created Task Based: pub async fn run_tasks( execute_data: Value, - mcp_tx: mpsc::Sender, - provider: Option>, - extension_manager: Option>>, + task_config: TaskConfig, ) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); @@ -135,9 +128,7 @@ pub async fn run_tasks( execute_data, execution_mode, notification_tx, - mcp_tx, - provider, - extension_manager, + task_config, ) .await { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index ec9da009b5c4..731e3e300598 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -4,25 +4,19 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use tokio::sync::mpsc; use tokio::time::timeout; -use crate::agents::extension_manager::ExtensionManager; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; use crate::agents::subagent_handler::run_complete_subagent_task; -use crate::providers::base::Provider; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::RwLock; +use crate::agents::task::TaskConfig; const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; pub async fn process_task( task: &Task, task_execution_tracker: Arc, - mcp_tx: mpsc::Sender, - provider: Option>, - extension_manager: Option>>, + task_config: TaskConfig, ) -> TaskResult { let timeout_in_seconds = task .timeout_in_seconds @@ -36,9 +30,7 @@ pub async fn process_task( get_task_result( task_clone, task_execution_tracker, - mcp_tx, - provider, - extension_manager, + task_config, ), ) .await @@ -76,18 +68,14 @@ pub async fn process_task( async fn get_task_result( task: Task, task_execution_tracker: Arc, - mcp_tx: mpsc::Sender, - provider: Option>, - extension_manager: Option>>, + task_config: TaskConfig, ) -> Result { if task.task_type == "text_instruction" { // Handle text_instruction tasks using subagent system handle_text_instruction_task( task, task_execution_tracker, - mcp_tx, - provider, - extension_manager, + task_config, ) .await } else { @@ -112,34 +100,38 @@ async fn get_task_result( async fn handle_text_instruction_task( task: Task, task_execution_tracker: Arc, - mcp_tx: mpsc::Sender, - provider: Option>, - extension_manager: Option>>, + task_config: TaskConfig, ) -> Result { let text_instruction = task .get_text_instruction() .ok_or_else(|| format!("Task {}: Missing text_instruction", task.id))?; - // Check if we have the required dependencies for subagent execution - let (provider, extension_manager) = match (provider, extension_manager) { - (Some(p), Some(em)) => (p, em), - _ => { - return Err( - "Text instruction tasks require provider and extension_manager".to_string(), - ); - } - }; + // Start tracking the task + task_execution_tracker.start_task(&task.id).await; + + // Send initial status update + task_execution_tracker + .send_live_output(&task.id, &format!("Starting text instruction task: {}", text_instruction)) + .await; + + // Send progress update + task_execution_tracker + .send_live_output(&task.id, "Initializing subagent for task execution...") + .await; // Create arguments for the subagent task - let arguments = serde_json::json!({ - "task": text_instruction, - "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", - "max_turns": 5, - "timeout_seconds": task.timeout_in_seconds.unwrap_or(300) + let task_arguments = serde_json::json!({ + "text_instruction": text_instruction, + // "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", }); // Execute the text instruction using the subagent system - match run_complete_subagent_task(arguments, mcp_tx, provider, Some(extension_manager)).await { + task_execution_tracker + .send_live_output(&task.id, "Executing text instruction with subagent...") + .await; + + println!("Kicking off subagent task! "); + match run_complete_subagent_task(task_arguments, task_config).await { Ok(contents) => { // Extract the text content from the result let result_text = contents @@ -151,11 +143,37 @@ async fn handle_text_instruction_task( .collect::>() .join("\n"); + // Send completion status + task_execution_tracker + .send_live_output(&task.id, "Text instruction task completed successfully") + .await; + + // Send result preview if it's not too long + if result_text.len() > 200 { + let preview = format!("Result preview: {}...", &result_text[..200]); + task_execution_tracker + .send_live_output(&task.id, &preview) + .await; + } else { + task_execution_tracker + .send_live_output(&task.id, &format!("Result: {}", result_text)) + .await; + } + Ok(serde_json::json!({ "result": result_text })) } - Err(e) => Err(format!("Subagent execution failed: {}", e)), + Err(e) => { + let error_msg = format!("Subagent execution failed: {}", e); + + // Send error status + task_execution_tracker + .send_live_output(&task.id, &error_msg) + .await; + + Err(error_msg) + } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index d5aff90e39db..7dfb76486b81 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,8 +1,7 @@ use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; +use crate::agents::task::TaskConfig; use std::sync::Arc; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::mpsc; async fn receive_task(state: &SharedState) -> Option { let mut receiver = state.task_receiver.lock().await; @@ -12,19 +11,19 @@ async fn receive_task(state: &SharedState) -> Option { pub fn spawn_worker( state: Arc, worker_id: usize, - mcp_tx: mpsc::Sender, + task_config: TaskConfig, ) -> tokio::task::JoinHandle<()> { state.increment_active_workers(); tokio::spawn(async move { - worker_loop(state, worker_id, mcp_tx).await; + worker_loop(state, worker_id, task_config).await; }) } -async fn worker_loop(state: Arc, _worker_id: usize, mcp_tx: mpsc::Sender) { +async fn worker_loop(state: Arc, _worker_id: usize, task_config: TaskConfig) { while let Some(task) = receive_task(&state).await { state.task_execution_tracker.start_task(&task.id).await; - let result = process_task(&task, state.task_execution_tracker.clone(), mcp_tx.clone(), None, None).await; + let result = process_task(&task, state.task_execution_tracker.clone(), task_config.clone()).await; if let Err(e) = state.result_sender.send(result).await { tracing::error!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index ba21ad323a58..97e756af3ee7 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -1,26 +1,19 @@ use crate::{ - agents::{extension_manager::ExtensionManager, Agent}, + agents::{Agent, TaskConfig}, message::{Message, MessageContent, ToolRequest}, prompt_template::render_global_file, - providers::base::Provider, providers::errors::ProviderError, - recipe::Recipe, }; use anyhow::anyhow; use chrono::{DateTime, Utc}; use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; -use mcp_core::{handler::ToolError, role::Role, tool::Tool}; +use mcp_core::{handler::ToolError, tool::Tool}; use serde::{Deserialize, Serialize}; use serde_json::{self, json}; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, instrument}; -use uuid::Uuid; -use crate::agents::platform_tools::{ - self, PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME, - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, -}; use crate::agents::subagent_tools::SUBAGENT_RUN_TASK_TOOL_NAME; /// Status of a subagent @@ -32,48 +25,6 @@ pub enum SubAgentStatus { Terminated, // Manually terminated } -/// Configuration for a subagent -#[derive(Debug)] -pub struct SubAgentConfig { - pub id: String, - pub recipe: Option, - pub instructions: Option, - pub max_turns: Option, - pub timeout_seconds: Option, -} - -impl SubAgentConfig { - pub fn new_with_recipe(recipe: Recipe) -> Self { - Self { - id: Uuid::new_v4().to_string(), - recipe: Some(recipe), - instructions: None, - max_turns: None, - timeout_seconds: None, - } - } - - pub fn new_with_instructions(instructions: String) -> Self { - Self { - id: Uuid::new_v4().to_string(), - recipe: None, - instructions: Some(instructions), - max_turns: None, - timeout_seconds: None, - } - } - - pub fn with_max_turns(mut self, max_turns: usize) -> Self { - self.max_turns = Some(max_turns); - self - } - - pub fn with_timeout(mut self, timeout_seconds: u64) -> Self { - self.timeout_seconds = Some(timeout_seconds); - self - } -} - /// Progress information for a subagent #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SubAgentProgress { @@ -90,60 +41,26 @@ pub struct SubAgent { pub id: String, pub conversation: Arc>>, pub status: Arc>, - pub config: SubAgentConfig, + pub config: TaskConfig, pub turn_count: Arc>, - pub created_at: DateTime, - pub recipe_extensions: Arc>>, - pub missing_extensions: Arc>>, // Track extensions that weren't enabled - pub mcp_notification_tx: mpsc::Sender, // For MCP notifications + pub created_at: DateTime } impl SubAgent { /// Create a new subagent with the given configuration and provider - #[instrument(skip(config, _provider, extension_manager, mcp_notification_tx))] + #[instrument(skip(task_config))] pub async fn new( - config: SubAgentConfig, - _provider: Arc, - extension_manager: Option>>, - mcp_notification_tx: mpsc::Sender, + task_config: TaskConfig, ) -> Result<(Arc, tokio::task::JoinHandle<()>), anyhow::Error> { - debug!("Creating new subagent with id: {}", config.id); - - let mut missing_extensions = Vec::new(); - let mut recipe_extensions = Vec::new(); - - // Check if extensions from recipe exist in the extension manager - if let Some(extension_manager) = &extension_manager { - if let Some(recipe) = &config.recipe { - if let Some(extensions) = &recipe.extensions { - for extension in extensions { - let extension_name = extension.name(); - let existing_extensions = extension_manager.read().await.list_extensions().await?; - - if !existing_extensions.contains(&extension_name) { - missing_extensions.push(extension_name); - } else { - recipe_extensions.push(extension_name); - } - } - } - } else { - // If no recipe, inherit all extensions from the parent agent - let existing_extensions = extension_manager.read().await.list_extensions().await?; - recipe_extensions = existing_extensions; - } - } + debug!("Creating new subagent with id: {}", task_config.id); let subagent = Arc::new(SubAgent { - id: config.id.clone(), + id: task_config.id.clone(), conversation: Arc::new(Mutex::new(Vec::new())), status: Arc::new(RwLock::new(SubAgentStatus::Ready)), - config, + config: task_config, turn_count: Arc::new(Mutex::new(0)), - created_at: Utc::now(), - recipe_extensions: Arc::new(Mutex::new(recipe_extensions)), - missing_extensions: Arc::new(Mutex::new(missing_extensions)), - mcp_notification_tx, + created_at: Utc::now() }); // Send initial MCP notification @@ -211,7 +128,7 @@ impl SubAgent { })), }); - if let Err(e) = self.mcp_notification_tx.send(notification).await { + if let Err(e) = self.config.mcp_tx.send(notification).await { error!( "Failed to send MCP notification from subagent {}: {}", self.id, e @@ -240,17 +157,23 @@ impl SubAgent { } /// Process a message and generate a response using the subagent's provider - #[instrument(skip(self, message, provider, extension_manager))] + #[instrument(skip(self, message))] pub async fn reply_subagent( &self, message: String, - provider: Arc, - extension_manager: Option>>, + task_config: TaskConfig, ) -> Result { debug!("Processing message for subagent {}", self.id); self.send_mcp_notification("message_processing", &format!("Processing: {}", message)) .await; + // Get provider and extension manager from task config + let provider = self.config.provider.as_ref() + .ok_or_else(|| anyhow!("No provider configured for subagent"))?; + + let extension_manager = self.config.extension_manager.as_ref() + .ok_or_else(|| anyhow!("No extension manager configured for subagent"))?; + // Check if we've exceeded max turns { let turn_count = *self.turn_count.lock().await; @@ -290,91 +213,10 @@ impl SubAgent { let mut messages = self.get_conversation().await; // Get tools based on whether we're using a recipe or inheriting from parent - let tools: Vec = if self.config.recipe.is_some() { - // Recipe mode: only get tools from the recipe's extensions - let recipe_extensions = self.recipe_extensions.lock().await; - let mut recipe_tools = Vec::new(); - - debug!( - "Subagent {} operating in recipe mode with {} extensions", - self.id, - recipe_extensions.len() - ); - - if let Some(extension_manager) = &extension_manager { - for extension_name in recipe_extensions.iter() { - match extension_manager - .read() - .await - .get_prefixed_tools(Some(extension_name.clone())) - .await - { - Ok(mut ext_tools) => { - debug!( - "Added {} tools from extension {}", - ext_tools.len(), - extension_name - ); - recipe_tools.append(&mut ext_tools); - } - Err(e) => { - debug!( - "Failed to get tools for extension {}: {}", - extension_name, e - ); - } - } - } - } - - debug!( - "Subagent {} has {} total recipe tools before filtering", - self.id, - recipe_tools.len() - ); - // Filter out subagent tools from recipe tools - let mut filtered_tools = Self::filter_subagent_tools(recipe_tools); - - // Add platform tools (except subagent tools) - if let Some(extension_manager) = &extension_manager { - Self::add_platform_tools(&mut filtered_tools, &extension_manager.read().await).await; - } - - debug!( - "Subagent {} has {} tools after filtering and adding platform tools", - self.id, - filtered_tools.len() - ); - filtered_tools + let tools: Vec = if self.config.tools.is_empty() { + vec![] } else { - // No recipe: inherit all tools from parent (but filter out subagent tools) - debug!( - "Subagent {} operating in inheritance mode, using all parent tools", - self.id - ); - let parent_tools = if let Some(extension_manager) = &extension_manager { - extension_manager.read().await.get_prefixed_tools(None).await? - } else { - Vec::new() - }; - debug!( - "Subagent {} has {} parent tools before filtering", - self.id, - parent_tools.len() - ); - let mut filtered_tools = Self::filter_subagent_tools(parent_tools); - - // Add platform tools (except subagent tools) - if let Some(extension_manager) = &extension_manager { - Self::add_platform_tools(&mut filtered_tools, &extension_manager.read().await).await; - } - - debug!( - "Subagent {} has {} tools after filtering and adding platform tools", - self.id, - filtered_tools.len() - ); - filtered_tools + self.config.tools.clone() }; let toolshim_tools: Vec = vec![]; @@ -385,7 +227,7 @@ impl SubAgent { // Generate response from provider loop { match Agent::generate_response_from_provider( - Arc::clone(&provider), + Arc::clone(provider), &system_prompt, &messages, &tools, @@ -441,31 +283,16 @@ impl SubAgent { .await; // Handle platform tools or dispatch to extension manager - let tool_result = if self.is_platform_tool(&tool_call.name) { - if let Some(extension_manager) = &extension_manager { - self.handle_platform_tool_call( - tool_call.clone(), - &extension_manager.read().await, - ) + let tool_result = + match extension_manager + .read() .await - } else { - Err(ToolError::ExecutionError("No extension manager available".to_string())) - } - } else { - if let Some(extension_manager) = &extension_manager { - match extension_manager - .read() - .await - .dispatch_tool_call(tool_call.clone()) - .await - { - Ok(result) => result.result.await, - Err(e) => Err(ToolError::ExecutionError(e.to_string())), - } - } else { - Err(ToolError::ExecutionError("No extension manager available".to_string())) - } - }; + .dispatch_tool_call(tool_call.clone()) + .await + { + Ok(result) => result.result.await, + Err(e) => Err(ToolError::ExecutionError(e.to_string())), + }; match tool_result { Ok(result) => { @@ -553,58 +380,9 @@ impl SubAgent { Ok(()) } - /// Get formatted conversation for display - pub async fn get_formatted_conversation(&self) -> String { - let conversation = self.conversation.lock().await; - - let mut formatted = format!("=== Subagent {} Conversation ===\n", self.id); - - if let Some(recipe) = &self.config.recipe { - formatted.push_str(&format!("Recipe: {}\n", recipe.title)); - } else if let Some(instructions) = &self.config.instructions { - formatted.push_str(&format!("Instructions: {}\n", instructions)); - } else { - formatted.push_str("Mode: Ad-hoc subagent\n"); - } - - formatted.push_str(&format!( - "Created: {}\n", - self.created_at.format("%Y-%m-%d %H:%M:%S UTC") - )); - - let progress = self.get_progress().await; - - formatted.push_str(&format!("Status: {:?}\n", progress.status)); - formatted.push_str(&format!("Turn: {}", progress.turn)); - if let Some(max_turns) = progress.max_turns { - formatted.push_str(&format!("/{}", max_turns)); - } - formatted.push_str("\n\n"); - - for (i, message) in conversation.iter().enumerate() { - formatted.push_str(&format!( - "{}. {}: {}\n", - i + 1, - match message.role { - Role::User => "User", - Role::Assistant => "Assistant", - }, - message.as_concat_text() - )); - } - - formatted.push_str("=== End Conversation ===\n"); - - formatted - } - - /// Get the list of extensions that weren't enabled - pub async fn get_missing_extensions(&self) -> Vec { - self.missing_extensions.lock().await.clone() - } - /// Filter out subagent spawning tools to prevent infinite recursion - fn filter_subagent_tools(tools: Vec) -> Vec { + fn _filter_subagent_tools(tools: Vec) -> Vec { + // TODO: add this in subagent loop let original_count = tools.len(); let filtered_tools: Vec = tools .into_iter() @@ -629,68 +407,6 @@ impl SubAgent { filtered_tools } - /// Add platform tools to the subagent's tool list (excluding dangerous tools) - async fn add_platform_tools(tools: &mut Vec, extension_manager: &ExtensionManager) { - debug!("Adding safe platform tools to subagent"); - - // Add safe platform tools - subagents can search for extensions but can't manage them or schedules - tools.push(platform_tools::search_available_extensions_tool()); - debug!("Added search_available_extensions tool"); - - // Add resource tools if supported - these are generally safe for subagents - if extension_manager.supports_resources() { - tools.extend([ - platform_tools::read_resource_tool(), - platform_tools::list_resources_tool(), - ]); - debug!("Added 2 resource platform tools"); - } - - // Note: We explicitly do NOT add these tools for security reasons: - // - manage_extensions (could interfere with parent agent's extensions) - // - manage_schedule (could interfere with parent agent's scheduling) - // - subagent spawning tools (prevent recursion) - debug!("Platform tools added successfully (dangerous tools excluded)"); - } - - /// Check if a tool name is a platform tool that subagents can use - fn is_platform_tool(&self, tool_name: &str) -> bool { - matches!( - tool_name, - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME - | PLATFORM_READ_RESOURCE_TOOL_NAME - | PLATFORM_LIST_RESOURCES_TOOL_NAME - ) - } - - /// Handle platform tool calls that are safe for subagents - async fn handle_platform_tool_call( - &self, - tool_call: mcp_core::tool::ToolCall, - extension_manager: &ExtensionManager, - ) -> Result, ToolError> { - debug!("Handling platform tool: {}", tool_call.name); - - match tool_call.name.as_str() { - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME => extension_manager - .search_available_extensions() - .await - .map_err(|e| ToolError::ExecutionError(e.to_string())), - PLATFORM_READ_RESOURCE_TOOL_NAME => extension_manager - .read_resource(tool_call.arguments) - .await - .map_err(|e| ToolError::ExecutionError(e.to_string())), - PLATFORM_LIST_RESOURCES_TOOL_NAME => extension_manager - .list_resources(tool_call.arguments) - .await - .map_err(|e| ToolError::ExecutionError(e.to_string())), - _ => Err(ToolError::ExecutionError(format!( - "Platform tool '{}' is not available to subagents for security reasons", - tool_call.name - ))), - } - } - /// Build the system prompt for the subagent using the template async fn build_system_prompt(&self, available_tools: &[Tool]) -> Result { let mut context = HashMap::new(); @@ -702,14 +418,6 @@ impl SubAgent { ); context.insert("subagent_id", serde_json::Value::String(self.id.clone())); - // Add recipe information if available - if let Some(recipe) = &self.config.recipe { - context.insert( - "recipe_title", - serde_json::Value::String(recipe.title.clone()), - ); - } - // Add max turns if configured if let Some(max_turns) = self.config.max_turns { context.insert( @@ -718,33 +426,6 @@ impl SubAgent { ); } - // Add task instructions - let instructions = if let Some(recipe) = &self.config.recipe { - recipe.instructions.as_deref().unwrap_or("") - } else { - self.config.instructions.as_deref().unwrap_or("") - }; - context.insert( - "task_instructions", - serde_json::Value::String(instructions.to_string()), - ); - - // Add available extensions (only if we have a recipe and extensions) - if self.config.recipe.is_some() { - let extensions: Vec = self.recipe_extensions.lock().await.clone(); - if !extensions.is_empty() { - context.insert( - "extensions", - serde_json::Value::Array( - extensions - .into_iter() - .map(serde_json::Value::String) - .collect(), - ), - ); - } - } - // Add available tools with descriptions for better context let tools_with_descriptions: Vec = available_tools .iter() diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 019c09f0ef44..159b87d3e898 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -1,58 +1,25 @@ use anyhow::Result; use mcp_core::{Content, ToolError}; -use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::sync::RwLock; -use crate::agents::extension_manager::ExtensionManager; -use crate::providers::base::Provider; -use crate::agents::subagent::{SubAgent, SubAgentConfig}; +use crate::agents::task::TaskConfig; +use crate::agents::subagent::SubAgent; /// Standalone function to run a complete subagent task pub async fn run_complete_subagent_task( - arguments: Value, - mcp_tx: mpsc::Sender, - provider: Arc, - extension_manager: Option>>, + task_arguments: Value, + task_config: TaskConfig, ) -> Result, ToolError> { // Parse arguments - using "task" as the main message parameter - let message = arguments - .get("task") + let text_instruction = task_arguments + .get("text_instruction") .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::ExecutionError("Missing task parameter".to_string()))? + .ok_or_else(|| ToolError::ExecutionError("Missing text_instruction parameter".to_string()))? .to_string(); - // Get instructions from arguments - let instructions = arguments - .get("instructions") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::ExecutionError("Missing instructions parameter".to_string()))? - .to_string(); - - // Set max_turns with default of 10 - let max_turns = arguments - .get("max_turns") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - - let timeout = arguments.get("timeout_seconds").and_then(|v| v.as_u64()); - - // Create subagent config with instructions - let mut config = SubAgentConfig::new_with_instructions(instructions); - config = config.with_max_turns(max_turns); - if let Some(timeout) = timeout { - config = config.with_timeout(timeout); - } - // Create the subagent with the parent agent's provider - let extension_manager_clone = extension_manager.clone(); let (subagent, handle) = SubAgent::new( - config, - Arc::clone(&provider), - extension_manager, - mcp_tx, + task_config.clone(), ) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to create subagent: {}", e)))?; @@ -61,12 +28,12 @@ pub async fn run_complete_subagent_task( let mut conversation_result = String::new(); let turn_count = 0; + println!("Subagent created, executing task..."); // Execute the subagent task match subagent .reply_subagent( - message, - Arc::clone(&provider), - extension_manager_clone, + text_instruction, + task_config ) .await { diff --git a/crates/goose/src/agents/subagent_manager.rs b/crates/goose/src/agents/subagent_manager.rs deleted file mode 100644 index 174faceecc1b..000000000000 --- a/crates/goose/src/agents/subagent_manager.rs +++ /dev/null @@ -1,404 +0,0 @@ -use std::collections::HashMap; -use std::path::Path; -use std::sync::Arc; - -use anyhow::{anyhow, Result}; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::{mpsc, Mutex, RwLock}; -use tracing::{debug, error, instrument, warn}; - -use crate::agents::extension_manager::ExtensionManager; -use crate::agents::subagent::{SubAgent, SubAgentConfig, SubAgentProgress, SubAgentStatus}; -use crate::agents::subagent_types::SpawnSubAgentArgs; -use crate::providers::base::Provider; -use crate::recipe::Recipe; - -/// Manages the lifecycle of subagents -pub struct SubAgentManager { - subagents: Arc>>>, - handles: Arc>>>, - mcp_notification_tx: mpsc::Sender, -} - -impl SubAgentManager { - /// Create a new subagent manager - pub fn new(mcp_notification_tx: mpsc::Sender) -> Self { - Self { - subagents: Arc::new(RwLock::new(HashMap::new())), - handles: Arc::new(Mutex::new(HashMap::new())), - mcp_notification_tx, - } - } - - /// Spawn a new interactive subagent - #[instrument(skip(self, args, provider, extension_manager))] - pub async fn spawn_interactive_subagent( - &self, - args: SpawnSubAgentArgs, - provider: Arc, - extension_manager: Arc>, - ) -> Result { - debug!("Spawning interactive subagent"); - - // Create subagent config based on whether we have a recipe or instructions - let mut config = if let Some(recipe_name) = args.recipe_name { - debug!("Using recipe: {}", recipe_name); - // Load the recipe - let recipe = self.load_recipe(&recipe_name).await?; - SubAgentConfig::new_with_recipe(recipe) - } else if let Some(instructions) = args.instructions { - debug!("Using direct instructions"); - SubAgentConfig::new_with_instructions(instructions) - } else { - return Err(anyhow!( - "Either recipe_name or instructions must be provided" - )); - }; - - if let Some(max_turns) = args.max_turns { - config = config.with_max_turns(max_turns); - } - if let Some(timeout) = args.timeout_seconds { - config = config.with_timeout(timeout); - } - - // Create the subagent with the parent agent's provider - let (subagent, handle) = SubAgent::new( - config, - Arc::clone(&provider), - Arc::clone(&extension_manager), - self.mcp_notification_tx.clone(), - ) - .await?; - let subagent_id = subagent.id.clone(); - - // Store the subagent and its handle - { - let mut subagents = self.subagents.write().await; - subagents.insert(subagent_id.clone(), Arc::clone(&subagent)); - } - { - let mut handles = self.handles.lock().await; - handles.insert(subagent_id.clone(), handle); - } - - // Return immediately - no initial message processing - Ok(subagent_id) - } - - /// Get a subagent by ID - pub async fn get_subagent(&self, id: &str) -> Option> { - let subagents = self.subagents.read().await; - subagents.get(id).cloned() - } - - /// List all active subagent IDs - pub async fn list_subagents(&self) -> Vec { - let subagents = self.subagents.read().await; - subagents.keys().cloned().collect() - } - - /// Get status of all subagents - pub async fn get_subagent_status(&self) -> HashMap { - let subagents = self.subagents.read().await; - let mut status_map = HashMap::new(); - - for (id, subagent) in subagents.iter() { - status_map.insert(id.clone(), subagent.get_status().await); - } - - status_map - } - - /// Get progress of all subagents - pub async fn get_subagent_progress(&self) -> HashMap { - let subagents = self.subagents.read().await; - let mut progress_map = HashMap::new(); - - for (id, subagent) in subagents.iter() { - progress_map.insert(id.clone(), subagent.get_progress().await); - } - - progress_map - } - - /// Send a message to a specific subagent - #[instrument(skip(self, message, provider, extension_manager))] - pub async fn send_message_to_subagent( - &self, - subagent_id: &str, - message: String, - provider: Arc, - extension_manager: Arc>, - ) -> Result { - let subagent = self - .get_subagent(subagent_id) - .await - .ok_or_else(|| anyhow!("Subagent {} not found", subagent_id))?; - - // Process the message and get a reply - match subagent - .reply_subagent(message, provider, extension_manager) - .await - { - Ok(response) => Ok(format!( - "Message sent to subagent {}. Response:\n{}", - subagent_id, - response.as_concat_text() - )), - Err(e) => Err(anyhow!("Failed to process message in subagent: {}", e)), - } - } - - /// Terminate a specific subagent - #[instrument(skip(self))] - pub async fn terminate_subagent(&self, id: &str) -> Result<()> { - debug!("Terminating subagent {}", id); - - // Get and terminate the subagent - let subagent = { - let mut subagents = self.subagents.write().await; - subagents.remove(id) - }; - - if let Some(subagent) = subagent { - subagent.terminate().await?; - } else { - warn!("Attempted to terminate non-existent subagent {}", id); - return Err(anyhow!("Subagent {} not found", id)); - } - - // Clean up the background handle - let handle = { - let mut handles = self.handles.lock().await; - handles.remove(id) - }; - - if let Some(handle) = handle { - handle.abort(); - } - - debug!("Subagent {} terminated successfully", id); - Ok(()) - } - - /// Terminate all subagents - #[instrument(skip(self))] - pub async fn terminate_all_subagents(&self) -> Result<()> { - debug!("Terminating all subagents"); - - let subagent_ids: Vec = { - let subagents = self.subagents.read().await; - subagents.keys().cloned().collect() - }; - - for id in subagent_ids { - if let Err(e) = self.terminate_subagent(&id).await { - error!("Failed to terminate subagent {}: {}", id, e); - } - } - - debug!("All subagents terminated"); - Ok(()) - } - - /// Get formatted conversation from a subagent - pub async fn get_subagent_conversation(&self, id: &str) -> Result { - let subagent = self - .get_subagent(id) - .await - .ok_or_else(|| anyhow!("Subagent {} not found", id))?; - - Ok(subagent.get_formatted_conversation().await) - } - - /// Clean up completed or failed subagents - pub async fn cleanup_completed_subagents(&self) -> Result { - let mut completed_ids = Vec::new(); - - // Find completed subagents - { - let subagents = self.subagents.read().await; - for (id, subagent) in subagents.iter() { - if subagent.is_completed().await { - completed_ids.push(id.clone()); - } - } - } - - // Remove completed subagents - let count = completed_ids.len(); - for id in completed_ids { - if let Err(e) = self.terminate_subagent(&id).await { - error!("Failed to cleanup completed subagent {}: {}", id, e); - } - } - - debug!("Cleaned up {} completed subagents", count); - Ok(count) - } - - /// Load a recipe from file - async fn load_recipe(&self, recipe_name: &str) -> Result { - // Try to load from current directory first - let recipe_path = if recipe_name.ends_with(".yaml") || recipe_name.ends_with(".yml") { - recipe_name.to_string() - } else { - format!("{}.yaml", recipe_name) - }; - - if Path::new(&recipe_path).exists() { - let content = tokio::fs::read_to_string(&recipe_path).await?; - let recipe: Recipe = serde_yaml::from_str(&content)?; - return Ok(recipe); - } - - // Try some common recipe locations - let common_paths = [ - format!("recipes/{}", recipe_path), - format!("./recipes/{}", recipe_path), - format!("../recipes/{}", recipe_path), - ]; - - for path in &common_paths { - if Path::new(path).exists() { - let content = tokio::fs::read_to_string(path).await?; - let recipe: Recipe = serde_yaml::from_str(&content)?; - return Ok(recipe); - } - } - - Err(anyhow!( - "Recipe file '{}' not found in current directory or common recipe locations", - recipe_name - )) - } - - /// Get count of active subagents - pub async fn get_active_count(&self) -> usize { - let subagents = self.subagents.read().await; - subagents.len() - } - - /// Check if a subagent exists - pub async fn has_subagent(&self, id: &str) -> bool { - let subagents = self.subagents.read().await; - subagents.contains_key(id) - } - - /// Run a complete subagent task (spawn, execute, cleanup) - #[instrument(skip(self, args, provider, extension_manager))] - pub async fn run_complete_subagent_task( - &self, - args: SpawnSubAgentArgs, - provider: Arc, - extension_manager: Arc>, - ) -> Result { - debug!("Running complete subagent task"); - - // Create subagent config based on whether we have a recipe or instructions - let mut config = if let Some(recipe_name) = args.recipe_name { - debug!("Using recipe: {}", recipe_name); - // Load the recipe - let recipe = self.load_recipe(&recipe_name).await?; - SubAgentConfig::new_with_recipe(recipe) - } else if let Some(instructions) = args.instructions { - debug!("Using direct instructions"); - SubAgentConfig::new_with_instructions(instructions) - } else { - return Err(anyhow!( - "Either recipe_name or instructions must be provided" - )); - }; - - // Set default max_turns if not provided - let max_turns = args.max_turns.unwrap_or(10); - config = config.with_max_turns(max_turns); - - if let Some(timeout) = args.timeout_seconds { - config = config.with_timeout(timeout); - } - - // Create the subagent with the parent agent's provider - let (subagent, handle) = SubAgent::new( - config, - Arc::clone(&provider), - Arc::clone(&extension_manager), - self.mcp_notification_tx.clone(), - ) - .await?; - let subagent_id = subagent.id.clone(); - - // Store the subagent and its handle temporarily - { - let mut subagents = self.subagents.write().await; - subagents.insert(subagent_id.clone(), Arc::clone(&subagent)); - } - { - let mut handles = self.handles.lock().await; - handles.insert(subagent_id.clone(), handle); - } - - // Run the complete conversation - let mut conversation_result = String::new(); - let turn_count = 0; - let current_message = args.message.clone(); - - // For now, we just complete after one turn since we don't have a mechanism - // for the subagent to continue autonomously without user input - // In a future iteration, we could add logic for the subagent to continue - // working on multi-step tasks with proper turn management - match subagent - .reply_subagent( - current_message, - Arc::clone(&provider), - Arc::clone(&extension_manager), - ) - .await - { - Ok(response) => { - let response_text = response.as_concat_text(); - conversation_result.push_str(&format!( - "\n--- Turn {} ---\n{}", - turn_count + 1, - response_text - )); - conversation_result.push_str(&format!( - "\n[Task completed after {} turns]", - turn_count + 1 - )); - } - Err(e) => { - conversation_result - .push_str(&format!("\n[Error after {} turns: {}]", turn_count, e)); - } - } - - // Clean up the subagent - if let Err(e) = self.terminate_subagent(&subagent_id).await { - debug!("Failed to cleanup subagent {}: {}", subagent_id, e); - } - - // Return the complete conversation result - Ok(format!("Subagent task completed:\n{}", conversation_result)) - } -} - -impl Default for SubAgentManager { - fn default() -> Self { - // Create a dummy channel for default implementation - // In practice, this should not be used - SubAgentManager should be created - // with a proper MCP notification sender - let (tx, _rx) = mpsc::channel(1); - Self::new(tx) - } -} - -impl Drop for SubAgentManager { - fn drop(&mut self) { - // Note: In a real implementation, you might want to spawn a task to clean up - // subagents gracefully, but for now we'll rely on the Drop implementations - // of the individual components - debug!("SubAgentManager dropped"); - } -} diff --git a/crates/goose/src/agents/task.rs b/crates/goose/src/agents/task.rs new file mode 100644 index 000000000000..1021c32bda2e --- /dev/null +++ b/crates/goose/src/agents/task.rs @@ -0,0 +1,74 @@ +use crate::providers::base::Provider; +use mcp_core::protocol::JsonRpcMessage; +use mcp_core::tool::Tool; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; +use uuid::Uuid; +use crate::agents::extension_manager::ExtensionManager; +use std::fmt; + +/// Configuration for task execution with all necessary dependencies +#[derive(Clone)] +pub struct TaskConfig { + pub id: String, + pub provider: Option>, + pub extension_manager: Option>>, + pub tools: Vec, + pub extensions: Vec, + pub mcp_tx: mpsc::Sender, + pub max_turns: Option, +} + +impl fmt::Debug for TaskConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TaskConfig") + .field("id", &self.id) + .field("provider", &"") + .field("extension_manager", &"") + .field("tools", &self.tools) + .field("extensions", &self.extensions) + .field("max_turns", &self.max_turns) + .finish() + } +} + +impl TaskConfig { + /// Create a new TaskConfig with all required dependencies + pub fn new( + provider: Option>, + extension_manager: Option>>, + tools: Vec, + extensions: Vec, + mcp_tx: mpsc::Sender, + ) -> Self { + Self { + id: Uuid::new_v4().to_string(), + provider, + extension_manager, + tools, + extensions, + mcp_tx, + max_turns: Some(10), + } + } + + /// Get a reference to the provider + pub fn provider(&self) -> Option<&Arc> { + self.provider.as_ref() + } + + /// Get a reference to the tools + pub fn tools(&self) -> &[Tool] { + &self.tools + } + + /// Get a reference to the extensions + pub fn extensions(&self) -> &[String] { + &self.extensions + } + + /// Get a clone of the MCP sender + pub fn mcp_tx(&self) -> mpsc::Sender { + self.mcp_tx.clone() + } +} \ No newline at end of file From 1f010d18326f65b652cd866d38874813333a2143 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 15 Jul 2025 14:02:13 -0700 Subject: [PATCH 32/43] pass extension manager to TaskConfig, remove list_tools to avoid lock block --- crates/goose/src/agents/agent.rs | 7 +++---- crates/goose/src/agents/subagent.rs | 11 ++++++----- crates/goose/src/agents/task.rs | 18 ------------------ 3 files changed, 9 insertions(+), 27 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index caa0fb043541..5d3b234c5b22 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -294,15 +294,14 @@ impl Agent { .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { - // Get the provider and extension manager for text instruction tasks - let tools = self.list_tools(None).await; - let extensions = self.list_extensions().await; + println!("About to call provider..."); let provider = self.provider().await.ok(); + println!("About to clone mcp_tx..."); let mcp_tx = self.mcp_tx.lock().await.clone(); println!("Executing tool call: {:?}", tool_call); - let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), tools, extensions, mcp_tx); + let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); sub_recipe_execute_task_tool::run_tasks( tool_call.arguments.clone(), task_config, diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 97e756af3ee7..392cdd6d7438 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -213,11 +213,12 @@ impl SubAgent { let mut messages = self.get_conversation().await; // Get tools based on whether we're using a recipe or inheriting from parent - let tools: Vec = if self.config.tools.is_empty() { - vec![] - } else { - self.config.tools.clone() - }; + let tools: Vec = extension_manager + .read() + .await + .get_prefixed_tools(None) + .await + .unwrap_or_default(); let toolshim_tools: Vec = vec![]; diff --git a/crates/goose/src/agents/task.rs b/crates/goose/src/agents/task.rs index 1021c32bda2e..2005829317ea 100644 --- a/crates/goose/src/agents/task.rs +++ b/crates/goose/src/agents/task.rs @@ -13,8 +13,6 @@ pub struct TaskConfig { pub id: String, pub provider: Option>, pub extension_manager: Option>>, - pub tools: Vec, - pub extensions: Vec, pub mcp_tx: mpsc::Sender, pub max_turns: Option, } @@ -25,8 +23,6 @@ impl fmt::Debug for TaskConfig { .field("id", &self.id) .field("provider", &"") .field("extension_manager", &"") - .field("tools", &self.tools) - .field("extensions", &self.extensions) .field("max_turns", &self.max_turns) .finish() } @@ -37,16 +33,12 @@ impl TaskConfig { pub fn new( provider: Option>, extension_manager: Option>>, - tools: Vec, - extensions: Vec, mcp_tx: mpsc::Sender, ) -> Self { Self { id: Uuid::new_v4().to_string(), provider, extension_manager, - tools, - extensions, mcp_tx, max_turns: Some(10), } @@ -57,16 +49,6 @@ impl TaskConfig { self.provider.as_ref() } - /// Get a reference to the tools - pub fn tools(&self) -> &[Tool] { - &self.tools - } - - /// Get a reference to the extensions - pub fn extensions(&self) -> &[String] { - &self.extensions - } - /// Get a clone of the MCP sender pub fn mcp_tx(&self) -> mpsc::Sender { self.mcp_tx.clone() From dc0ec41b86eb069f35f85bf56cf1adbb9345d357 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 15 Jul 2025 14:10:08 -0700 Subject: [PATCH 33/43] control output --- crates/goose/src/agents/agent.rs | 15 ++---- .../sub_recipe_execution_tool/executor/mod.rs | 25 ++++----- .../sub_recipe_execution_tool/lib/mod.rs | 7 ++- .../sub_recipe_execute_task_tool.rs | 20 +++---- .../agents/sub_recipe_execution_tool/tasks.rs | 52 +------------------ .../sub_recipe_execution_tool/workers.rs | 7 ++- crates/goose/src/agents/subagent.rs | 33 +++++++----- crates/goose/src/agents/subagent_handler.rs | 26 ++++------ crates/goose/src/agents/task.rs | 6 +-- 9 files changed, 64 insertions(+), 127 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 5d3b234c5b22..ccb7c95f8615 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -15,8 +15,8 @@ use crate::agents::recipe_tools::dynamic_task_tools::{ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; -use crate::agents::task::TaskConfig; use crate::agents::sub_recipe_manager::SubRecipeManager; +use crate::agents::task::TaskConfig; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::Message; use crate::permission::permission_judge::check_tool_permissions; @@ -294,19 +294,12 @@ impl Agent { .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { - println!("About to call provider..."); let provider = self.provider().await.ok(); - println!("About to clone mcp_tx..."); let mcp_tx = self.mcp_tx.lock().await.clone(); - println!("Executing tool call: {:?}", tool_call); - - let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); - sub_recipe_execute_task_tool::run_tasks( - tool_call.arguments.clone(), - task_config, - ) - .await + let task_config = + TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); + sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone(), task_config).await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index ba999cebfea0..b265f4aebbb4 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -1,8 +1,3 @@ -use mcp_core::protocol::JsonRpcMessage; -use std::sync::atomic::AtomicUsize; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::lib::{ ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -12,6 +7,11 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; use crate::agents::task::TaskConfig; +use mcp_core::protocol::JsonRpcMessage; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::time::Instant; #[cfg(test)] mod tests; @@ -30,16 +30,13 @@ pub async fn execute_single_task( DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task( - task, - task_execution_tracker.clone(), - task_config - ) - .await; - + let result = process_task(task, task_execution_tracker.clone(), task_config).await; + // Complete the task in the tracker - task_execution_tracker.complete_task(&result.task_id, result.clone()).await; - + task_execution_tracker + .complete_task(&result.task_id, result.clone()) + .await; + let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index fd33476d8215..04f4117e29ab 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -23,16 +23,15 @@ pub async fn execute_tasks( match execution_mode { ExecutionMode::Sequential => { if task_count == 1 { - let response = - execute_single_task(&tasks[0], notifier, task_config) - .await; + let response = execute_single_task(&tasks[0], notifier, task_config).await; handle_response(response) } else { Err("Sequential execution mode requires exactly one task".to_string()) } } ExecutionMode::Parallel => { - let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier, task_config).await; + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, notifier, task_config).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index e32c49101f95..64ec6ab41cc6 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -1,11 +1,11 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; +use crate::agents::task::TaskConfig; use crate::agents::{ sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, }; -use crate::agents::task::TaskConfig; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; use tokio_stream; @@ -112,10 +112,7 @@ Pre-created Task Based: ) } -pub async fn run_tasks( - execute_data: Value, - task_config: TaskConfig, -) -> ToolCallResult { +pub async fn run_tasks(execute_data: Value, task_config: TaskConfig) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); let execution_mode = execute_data @@ -124,14 +121,7 @@ pub async fn run_tasks( .unwrap_or_default(); let result_future = async move { - match execute_tasks( - execute_data, - execution_mode, - notification_tx, - task_config, - ) - .await - { + match execute_tasks(execute_data, execution_mode, notification_tx, task_config).await { Ok(result) => { let output = serde_json::to_string(&result).unwrap(); Ok(vec![Content::text(output)]) @@ -142,6 +132,8 @@ pub async fn run_tasks( ToolCallResult { result: Box::new(Box::pin(result_future)), - notification_stream: Some(Box::new(tokio_stream::wrappers::ReceiverStream::new(notification_rx))), + notification_stream: Some(Box::new(tokio_stream::wrappers::ReceiverStream::new( + notification_rx, + ))), } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 731e3e300598..967988b0a534 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -27,11 +27,7 @@ pub async fn process_task( let task_execution_tracker_clone = task_execution_tracker.clone(); match timeout( timeout_duration, - get_task_result( - task_clone, - task_execution_tracker, - task_config, - ), + get_task_result(task_clone, task_execution_tracker, task_config), ) .await { @@ -72,12 +68,7 @@ async fn get_task_result( ) -> Result { if task.task_type == "text_instruction" { // Handle text_instruction tasks using subagent system - handle_text_instruction_task( - task, - task_execution_tracker, - task_config, - ) - .await + handle_text_instruction_task(task, task_execution_tracker, task_config).await } else { // Handle sub_recipe tasks using command execution let (command, output_identifier) = build_command(&task)?; @@ -109,28 +100,12 @@ async fn handle_text_instruction_task( // Start tracking the task task_execution_tracker.start_task(&task.id).await; - // Send initial status update - task_execution_tracker - .send_live_output(&task.id, &format!("Starting text instruction task: {}", text_instruction)) - .await; - - // Send progress update - task_execution_tracker - .send_live_output(&task.id, "Initializing subagent for task execution...") - .await; - // Create arguments for the subagent task let task_arguments = serde_json::json!({ "text_instruction": text_instruction, // "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", }); - // Execute the text instruction using the subagent system - task_execution_tracker - .send_live_output(&task.id, "Executing text instruction with subagent...") - .await; - - println!("Kicking off subagent task! "); match run_complete_subagent_task(task_arguments, task_config).await { Ok(contents) => { // Extract the text content from the result @@ -143,35 +118,12 @@ async fn handle_text_instruction_task( .collect::>() .join("\n"); - // Send completion status - task_execution_tracker - .send_live_output(&task.id, "Text instruction task completed successfully") - .await; - - // Send result preview if it's not too long - if result_text.len() > 200 { - let preview = format!("Result preview: {}...", &result_text[..200]); - task_execution_tracker - .send_live_output(&task.id, &preview) - .await; - } else { - task_execution_tracker - .send_live_output(&task.id, &format!("Result: {}", result_text)) - .await; - } - Ok(serde_json::json!({ "result": result_text })) } Err(e) => { let error_msg = format!("Subagent execution failed: {}", e); - - // Send error status - task_execution_tracker - .send_live_output(&task.id, &error_msg) - .await; - Err(error_msg) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 7dfb76486b81..5fd9ff04ffb5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -23,7 +23,12 @@ pub fn spawn_worker( async fn worker_loop(state: Arc, _worker_id: usize, task_config: TaskConfig) { while let Some(task) = receive_task(&state).await { state.task_execution_tracker.start_task(&task.id).await; - let result = process_task(&task, state.task_execution_tracker.clone(), task_config.clone()).await; + let result = process_task( + &task, + state.task_execution_tracker.clone(), + task_config.clone(), + ) + .await; if let Err(e) = state.result_sender.send(result).await { tracing::error!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 392cdd6d7438..943a6670cea2 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -43,7 +43,7 @@ pub struct SubAgent { pub status: Arc>, pub config: TaskConfig, pub turn_count: Arc>, - pub created_at: DateTime + pub created_at: DateTime, } impl SubAgent { @@ -60,7 +60,7 @@ impl SubAgent { status: Arc::new(RwLock::new(SubAgentStatus::Ready)), config: task_config, turn_count: Arc::new(Mutex::new(0)), - created_at: Utc::now() + created_at: Utc::now(), }); // Send initial MCP notification @@ -168,10 +168,16 @@ impl SubAgent { .await; // Get provider and extension manager from task config - let provider = self.config.provider.as_ref() + let provider = self + .config + .provider + .as_ref() .ok_or_else(|| anyhow!("No provider configured for subagent"))?; - let extension_manager = self.config.extension_manager.as_ref() + let extension_manager = self + .config + .extension_manager + .as_ref() .ok_or_else(|| anyhow!("No extension manager configured for subagent"))?; // Check if we've exceeded max turns @@ -284,16 +290,15 @@ impl SubAgent { .await; // Handle platform tools or dispatch to extension manager - let tool_result = - match extension_manager - .read() - .await - .dispatch_tool_call(tool_call.clone()) - .await - { - Ok(result) => result.result.await, - Err(e) => Err(ToolError::ExecutionError(e.to_string())), - }; + let tool_result = match extension_manager + .read() + .await + .dispatch_tool_call(tool_call.clone()) + .await + { + Ok(result) => result.result.await, + Err(e) => Err(ToolError::ExecutionError(e.to_string())), + }; match tool_result { Ok(result) => { diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 159b87d3e898..99b70c816b4a 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -2,8 +2,8 @@ use anyhow::Result; use mcp_core::{Content, ToolError}; use serde_json::Value; -use crate::agents::task::TaskConfig; use crate::agents::subagent::SubAgent; +use crate::agents::task::TaskConfig; /// Standalone function to run a complete subagent task pub async fn run_complete_subagent_task( @@ -18,11 +18,9 @@ pub async fn run_complete_subagent_task( .to_string(); // Create the subagent with the parent agent's provider - let (subagent, handle) = SubAgent::new( - task_config.clone(), - ) - .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to create subagent: {}", e)))?; + let (subagent, handle) = SubAgent::new(task_config.clone()) + .await + .map_err(|e| ToolError::ExecutionError(format!("Failed to create subagent: {}", e)))?; // Run the complete conversation let mut conversation_result = String::new(); @@ -30,13 +28,7 @@ pub async fn run_complete_subagent_task( println!("Subagent created, executing task..."); // Execute the subagent task - match subagent - .reply_subagent( - text_instruction, - task_config - ) - .await - { + match subagent.reply_subagent(text_instruction, task_config).await { Ok(response) => { let response_text = response.as_concat_text(); conversation_result.push_str(&format!( @@ -50,8 +42,7 @@ pub async fn run_complete_subagent_task( )); } Err(e) => { - conversation_result - .push_str(&format!("\n[Error after {} turns: {}]", turn_count, e)); + conversation_result.push_str(&format!("\n[Error after {} turns: {}]", turn_count, e)); } } @@ -61,5 +52,8 @@ pub async fn run_complete_subagent_task( } // Return the complete conversation result - Ok(vec![Content::text(format!("Subagent task completed:\n{}", conversation_result))]) + Ok(vec![Content::text(format!( + "Subagent task completed:\n{}", + conversation_result + ))]) } diff --git a/crates/goose/src/agents/task.rs b/crates/goose/src/agents/task.rs index 2005829317ea..d6cae9bfd945 100644 --- a/crates/goose/src/agents/task.rs +++ b/crates/goose/src/agents/task.rs @@ -1,11 +1,11 @@ +use crate::agents::extension_manager::ExtensionManager; use crate::providers::base::Provider; use mcp_core::protocol::JsonRpcMessage; use mcp_core::tool::Tool; +use std::fmt; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use uuid::Uuid; -use crate::agents::extension_manager::ExtensionManager; -use std::fmt; /// Configuration for task execution with all necessary dependencies #[derive(Clone)] @@ -53,4 +53,4 @@ impl TaskConfig { pub fn mcp_tx(&self) -> mpsc::Sender { self.mcp_tx.clone() } -} \ No newline at end of file +} From 0c513b4ff521bd81a3214d5aba261afbab5a7bef Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 15 Jul 2025 15:37:44 -0700 Subject: [PATCH 34/43] mute notifications' --- crates/goose-cli/src/session/mod.rs | 15 ++-- .../src/session/task_execution_display/mod.rs | 4 +- .../session/task_execution_display/tests.rs | 2 +- crates/goose/src/agents/agent.rs | 22 +++--- crates/goose/src/agents/mod.rs | 9 +-- .../agents/recipe_tools/dynamic_task_tools.rs | 33 ++++++--- .../agents/recipe_tools/sub_recipe_tools.rs | 2 +- .../agents/sub_recipe_execution_tool/mod.rs | 9 --- crates/goose/src/agents/subagent.rs | 25 +------ .../executor/mod.rs | 25 ++++--- .../executor/tests.rs | 0 .../lib/mod.rs | 6 +- .../lib/tests.rs | 0 .../src/agents/subagent_execution_tool/mod.rs | 9 +++ .../notification_events.rs | 2 +- .../subagent_execute_task_tool.rs} | 14 ++-- .../task_execution_tracker.rs | 8 +-- .../task_types.rs | 2 +- .../tasks.rs | 6 +- .../utils/mod.rs | 5 +- .../utils/tests.rs | 0 .../workers.rs | 6 +- crates/goose/src/agents/subagent_handler.rs | 38 +++-------- .../{task.rs => subagent_task_config.rs} | 1 - crates/goose/src/agents/subagent_tools.rs | 68 ------------------- crates/goose/src/agents/subagent_types.rs | 42 ------------ 26 files changed, 107 insertions(+), 246 deletions(-) delete mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/mod.rs rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/executor/mod.rs (88%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/executor/tests.rs (100%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/lib/mod.rs (94%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/lib/tests.rs (100%) create mode 100644 crates/goose/src/agents/subagent_execution_tool/mod.rs rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/notification_events.rs (98%) rename crates/goose/src/agents/{sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs => subagent_execution_tool/subagent_execute_task_tool.rs} (93%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/task_execution_tracker.rs (97%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/task_types.rs (97%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/tasks.rs (97%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/utils/mod.rs (88%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/utils/tests.rs (100%) rename crates/goose/src/agents/{sub_recipe_execution_tool => subagent_execution_tool}/workers.rs (84%) rename crates/goose/src/agents/{task.rs => subagent_task_config.rs} (98%) delete mode 100644 crates/goose/src/agents/subagent_tools.rs delete mode 100644 crates/goose/src/agents/subagent_types.rs diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 61e434568e10..33e75d8bb028 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1031,13 +1031,14 @@ impl Session { // Handle subagent notifications - show immediately if let Some(_id) = subagent_id { - // Show subagent notifications immediately (no buffering) with compact spacing - if interactive { - let _ = progress_bars.hide(); - println!("{}", console::style(&formatted_message).green().dim()); - } else { - progress_bars.log(&formatted_message); - } + // TODO: proper display for subagent notifications + // if interactive { + // let _ = progress_bars.hide(); + // println!("{}", console::style(&formatted_message).green().dim()); + // } else { + // progress_bars.log(&formatted_message); + // } + continue; } else if let Some(ref notification_type) = _notification_type { if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { if interactive { diff --git a/crates/goose-cli/src/session/task_execution_display/mod.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs index 96d37b76d483..2ab4cb5ddf4d 100644 --- a/crates/goose-cli/src/session/task_execution_display/mod.rs +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -1,7 +1,7 @@ -use goose::agents::sub_recipe_execution_tool::lib::TaskStatus; -use goose::agents::sub_recipe_execution_tool::notification_events::{ +use goose::agents::subagent_execution_tool::notification_events::{ TaskExecutionNotificationEvent, TaskInfo, }; +use goose::agents::subagent_execution_tool::task_types::TaskStatus; use serde_json::Value; use std::sync::atomic::{AtomicBool, Ordering}; diff --git a/crates/goose-cli/src/session/task_execution_display/tests.rs b/crates/goose-cli/src/session/task_execution_display/tests.rs index fb53285080d3..725d161dff5b 100644 --- a/crates/goose-cli/src/session/task_execution_display/tests.rs +++ b/crates/goose-cli/src/session/task_execution_display/tests.rs @@ -1,5 +1,5 @@ use super::*; -use goose::agents::sub_recipe_execution_tool::notification_events::{ +use goose::agents::subagent_execution_tool::notification_events::{ FailedTaskInfo, TaskCompletionStats, TaskExecutionStats, }; use serde_json::json; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index ccb7c95f8615..cfc56805374c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -12,11 +12,10 @@ use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_ use crate::agents::recipe_tools::dynamic_task_tools::{ create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, }; -use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ - self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, -}; use crate::agents::sub_recipe_manager::SubRecipeManager; -use crate::agents::task::TaskConfig; +use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ + self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, +}; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::Message; use crate::permission::permission_judge::check_tool_permissions; @@ -55,6 +54,7 @@ use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; +use crate::agents::subagent_task_config::TaskConfig; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -293,13 +293,13 @@ impl Agent { sub_recipe_manager .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) .await - } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { + } else if tool_call.name == SUBAGENT_EXECUTE_TASK_TOOL_NAME { let provider = self.provider().await.ok(); let mcp_tx = self.mcp_tx.lock().await.clone(); let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); - sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone(), task_config).await + subagent_execute_task_tool::run_tasks(tool_call.arguments.clone(), task_config).await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { @@ -558,11 +558,8 @@ impl Agent { platform_tools::manage_schedule_tool(), ]); - // Add subagent tool (only if ALPHA_FEATURES is enabled) - let config = Config::global(); - if config.get_param::("ALPHA_FEATURES").unwrap_or(false) { - prefixed_tools.push(create_dynamic_task_tool()); - } + // Dynamic task tool + prefixed_tools.push(create_dynamic_task_tool()); // Add resource tools if supported if extension_manager.supports_resources() { @@ -580,8 +577,7 @@ impl Agent { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { prefixed_tools.push(final_output_tool.tool()); } - prefixed_tools - .push(sub_recipe_execute_task_tool::create_sub_recipe_execute_task_tool()); + prefixed_tools.push(subagent_execute_task_tool::create_subagent_execute_task_tool()); } prefixed_tools diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 66e87be43386..ffcc2b9ceffa 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -11,13 +11,11 @@ mod reply_parts; mod router_tool_selector; mod router_tools; mod schedule_tool; -pub mod sub_recipe_execution_tool; pub mod sub_recipe_manager; pub mod subagent; +pub mod subagent_execution_tool; pub mod subagent_handler; -pub mod subagent_tools; -pub mod subagent_types; -mod task; +mod subagent_task_config; mod tool_execution; mod tool_router_index_manager; pub(crate) mod tool_vectordb; @@ -28,6 +26,5 @@ pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; pub use subagent::{SubAgent, SubAgentProgress, SubAgentStatus}; -pub use subagent_types::SpawnSubAgentArgs; -pub use task::TaskConfig; +pub use subagent_task_config::TaskConfig; pub use types::{FrontendTool, SessionConfig}; diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index 77a66bee3fac..6bbafe872815 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -5,7 +5,7 @@ use crate::agents::recipe_tools::sub_recipe_tools::{ EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, }; -use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::subagent_execution_tool::task_types::Task; use crate::agents::tool_execution::ToolCallResult; use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::{json, Value}; @@ -16,13 +16,30 @@ pub fn create_dynamic_task_tool() -> Tool { Tool::new( format!("{}", DYNAMIC_TASK_TOOL_NAME_PREFIX), format!( - "Creates a dynamic task object(s) based on textual instructions. \ - Provide an array of parameter sets in the 'task_parameters' field:\n\ - - For a single task: provide an array with one parameter set\n\ - - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ - Each task will run the same text instruction but with different parameter values. \ - This is useful when you need to execute the same instruction multiple times with varying inputs. \ - After creating the task list, pass it to the task executor to run all tasks." + "Use this tool to create one or more dynamic tasks from a shared text instruction and varying parameters.\ + How it works: + - Provide a single text instruction + - Use the 'task_parameters' field to pass an array of parameter sets + - Each resulting task will use the same instruction with different parameter values + This is useful when performing the same operation across many inputs (e.g., getting weather for multiple cities, searching multiple slack channels, iterating through various linear tickets, etc). + Once created, these tasks should be passed to the 'subagent__execute_task' tool for execution. Tasks can run sequentially or in parallel. + --- + What is a 'subagent'? + A 'subagent' is a stateless sub-process that executes a single task independently. Use subagents when: + - You want to parallelize similar work across different inputs + - You are not sure your search or operation will succeed on the first try + Each subagent receives a task with a defined payload and returns a result, which is not visible to the user unless explicitly summarized by the system. + --- + Examples of 'task_parameters' for a single task: + text_instruction: Search for the config file in the root directory. + Examples of 'task_parameters' for multiple tasks: + text_instruction: Get weather for Melbourne. + timeout_seconds: 300 + text_instruction: Get weather for Los Angeles. + timeout_seconds: 300 + text_instruction: Get weather for San Francisco. + timeout_seconds: 300 + " ), json!({ "type": "object", diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index bc0347dfd509..68cacc872df7 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -5,7 +5,7 @@ use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; -use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::subagent_execution_tool::task_types::Task; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; use super::param_utils::prepare_command_params; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs deleted file mode 100644 index b6363ba20d14..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod executor; -pub mod lib; -pub mod notification_events; -pub mod sub_recipe_execute_task_tool; -mod task_execution_tracker; -mod task_types; -mod tasks; -pub mod utils; -mod workers; diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 943a6670cea2..030c787732b6 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -14,8 +14,6 @@ use std::{collections::HashMap, sync::Arc}; use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, instrument}; -use crate::agents::subagent_tools::SUBAGENT_RUN_TASK_TOOL_NAME; - /// Status of a subagent #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum SubAgentStatus { @@ -389,28 +387,7 @@ impl SubAgent { /// Filter out subagent spawning tools to prevent infinite recursion fn _filter_subagent_tools(tools: Vec) -> Vec { // TODO: add this in subagent loop - let original_count = tools.len(); - let filtered_tools: Vec = tools - .into_iter() - .filter(|tool| { - let should_keep = tool.name != SUBAGENT_RUN_TASK_TOOL_NAME; - if !should_keep { - debug!("Filtering out subagent tool: {}", tool.name); - } - should_keep - }) - .collect(); - - let filtered_count = filtered_tools.len(); - if filtered_count < original_count { - debug!( - "Filtered {} subagent tool(s) from {} total tools", - original_count - filtered_count, - original_count - ); - } - - filtered_tools + tools } /// Build the system prompt for the subagent using the template diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs similarity index 88% rename from crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs rename to crates/goose/src/agents/subagent_execution_tool/executor/mod.rs index b265f4aebbb4..9a71ad4ad7f9 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -1,21 +1,18 @@ -use crate::agents::sub_recipe_execution_tool::lib::{ +use crate::agents::subagent_execution_tool::lib::{ ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; -use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ +use crate::agents::subagent_execution_tool::task_execution_tracker::{ DisplayMode, TaskExecutionTracker, }; -use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; -use crate::agents::task::TaskConfig; +use crate::agents::subagent_execution_tool::tasks::process_task; +use crate::agents::subagent_execution_tool::workers::spawn_worker; +use crate::agents::subagent_task_config::TaskConfig; use mcp_core::protocol::JsonRpcMessage; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; -#[cfg(test)] -mod tests; - const EXECUTION_STATUS_COMPLETED: &str = "completed"; const DEFAULT_MAX_WORKERS: usize = 10; @@ -171,17 +168,25 @@ fn create_empty_response() -> ExecutionResponse { }, } } - async fn collect_results( result_rx: &mut mpsc::Receiver, task_execution_tracker: Arc, expected_count: usize, ) -> Vec { let mut results = Vec::new(); - while let Some(result) = result_rx.recv().await { + while let Some(mut result) = result_rx.recv().await { + // Truncate data to 650 chars if needed + if let Some(data) = result.data.as_mut() { + if let Some(data_str) = data.as_str() { + if data_str.len() > 650 { + *data = serde_json::Value::String(format!("{}...", &data_str[..650])); + } + } + } task_execution_tracker .complete_task(&result.task_id, result.clone()) .await; + results.push(result); if results.len() >= expected_count { break; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs b/crates/goose/src/agents/subagent_execution_tool/executor/tests.rs similarity index 100% rename from crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs rename to crates/goose/src/agents/subagent_execution_tool/executor/tests.rs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs similarity index 94% rename from crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs rename to crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 04f4117e29ab..aba50c087fa9 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -1,10 +1,10 @@ -use crate::agents::sub_recipe_execution_tool::executor::{ +use crate::agents::subagent_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; -pub use crate::agents::sub_recipe_execution_tool::task_types::{ +pub use crate::agents::subagent_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; -use crate::agents::task::TaskConfig; +use crate::agents::subagent_task_config::TaskConfig; use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; use tokio::sync::mpsc; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs b/crates/goose/src/agents/subagent_execution_tool/lib/tests.rs similarity index 100% rename from crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs rename to crates/goose/src/agents/subagent_execution_tool/lib/tests.rs diff --git a/crates/goose/src/agents/subagent_execution_tool/mod.rs b/crates/goose/src/agents/subagent_execution_tool/mod.rs new file mode 100644 index 000000000000..9dab0001d031 --- /dev/null +++ b/crates/goose/src/agents/subagent_execution_tool/mod.rs @@ -0,0 +1,9 @@ +mod executor; +pub mod lib; +pub mod notification_events; +pub mod subagent_execute_task_tool; +pub mod task_execution_tracker; +pub mod task_types; +pub mod tasks; +pub mod utils; +pub mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs b/crates/goose/src/agents/subagent_execution_tool/notification_events.rs similarity index 98% rename from crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs rename to crates/goose/src/agents/subagent_execution_tool/notification_events.rs index 2a6134ea1a55..632cb976b94c 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs +++ b/crates/goose/src/agents/subagent_execution_tool/notification_events.rs @@ -1,4 +1,4 @@ -use crate::agents::sub_recipe_execution_tool::task_types::TaskStatus; +use crate::agents::subagent_execution_tool::task_types::TaskStatus; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs similarity index 93% rename from crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs rename to crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index 64ec6ab41cc6..4b0cc5eaacd6 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -1,20 +1,20 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; -use crate::agents::task::TaskConfig; +use crate::agents::subagent_task_config::TaskConfig; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, - sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, + subagent_execution_tool::lib::execute_tasks, + subagent_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; use tokio_stream; -pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; -pub fn create_sub_recipe_execute_task_tool() -> Tool { +pub const SUBAGENT_EXECUTE_TASK_TOOL_NAME: &str = "subagent__execute_task"; +pub fn create_subagent_execute_task_tool() -> Tool { Tool::new( - SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, - "Only use this tool when you execute sub recipe task. + SUBAGENT_EXECUTE_TASK_TOOL_NAME, + "Only use the subagent__execute_task tool when you execute sub recipe task or dynamic task. EXECUTION STRATEGY DECISION: 1. PRE-CREATED TASKS: If tasks were created by subrecipe__create_task_* tools, check the execution_mode in the response: - If execution_mode is 'parallel', use parallel execution diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs similarity index 97% rename from crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs rename to crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index b456fd77424f..7d3cf347f28e 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -5,14 +5,12 @@ use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; -use crate::agents::sub_recipe_execution_tool::notification_events::{ +use crate::agents::subagent_execution_tool::notification_events::{ FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, TaskInfo as EventTaskInfo, }; -use crate::agents::sub_recipe_execution_tool::task_types::{ - Task, TaskInfo, TaskResult, TaskStatus, -}; -use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; +use crate::agents::subagent_execution_tool::task_types::{Task, TaskInfo, TaskResult, TaskStatus}; +use crate::agents::subagent_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::Value; #[derive(Debug, Clone, PartialEq)] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs b/crates/goose/src/agents/subagent_execution_tool/task_types.rs similarity index 97% rename from crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs rename to crates/goose/src/agents/subagent_execution_tool/task_types.rs index ea31746032d7..92a79318b5fc 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_types.rs @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; -use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; +use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(rename_all = "lowercase")] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs similarity index 97% rename from crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs rename to crates/goose/src/agents/subagent_execution_tool/tasks.rs index 967988b0a534..a34b8ebfe57b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -6,10 +6,10 @@ use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; -use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; +use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; +use crate::agents::subagent_execution_tool::task_types::{Task, TaskResult, TaskStatus}; use crate::agents::subagent_handler::run_complete_subagent_task; -use crate::agents::task::TaskConfig; +use crate::agents::subagent_task_config::TaskConfig; const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs similarity index 88% rename from crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs rename to crates/goose/src/agents/subagent_execution_tool/utils/mod.rs index 1ead865e6571..2f6791ff8278 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::agents::sub_recipe_execution_tool::task_types::{TaskInfo, TaskStatus}; +use crate::agents::subagent_execution_tool::task_types::{TaskInfo, TaskStatus}; pub fn get_task_name(task_info: &TaskInfo) -> &str { task_info @@ -22,6 +22,3 @@ pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usiz ); (total, pending, running, completed, failed) } - -#[cfg(test)] -mod tests; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs similarity index 100% rename from crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs rename to crates/goose/src/agents/subagent_execution_tool/utils/tests.rs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/subagent_execution_tool/workers.rs similarity index 84% rename from crates/goose/src/agents/sub_recipe_execution_tool/workers.rs rename to crates/goose/src/agents/subagent_execution_tool/workers.rs index 5fd9ff04ffb5..4ae0ab250737 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/subagent_execution_tool/workers.rs @@ -1,6 +1,6 @@ -use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; -use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::task::TaskConfig; +use crate::agents::subagent_execution_tool::task_types::{SharedState, Task}; +use crate::agents::subagent_execution_tool::tasks::process_task; +use crate::agents::subagent_task_config::TaskConfig; use std::sync::Arc; async fn receive_task(state: &SharedState) -> Option { diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 99b70c816b4a..f40a34d4dcfe 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -1,10 +1,9 @@ +use crate::agents::subagent::SubAgent; +use crate::agents::subagent_task_config::TaskConfig; use anyhow::Result; use mcp_core::{Content, ToolError}; use serde_json::Value; -use crate::agents::subagent::SubAgent; -use crate::agents::task::TaskConfig; - /// Standalone function to run a complete subagent task pub async fn run_complete_subagent_task( task_arguments: Value, @@ -22,38 +21,23 @@ pub async fn run_complete_subagent_task( .await .map_err(|e| ToolError::ExecutionError(format!("Failed to create subagent: {}", e)))?; - // Run the complete conversation - let mut conversation_result = String::new(); - let turn_count = 0; - - println!("Subagent created, executing task..."); // Execute the subagent task - match subagent.reply_subagent(text_instruction, task_config).await { + let result = match subagent.reply_subagent(text_instruction, task_config).await { Ok(response) => { let response_text = response.as_concat_text(); - conversation_result.push_str(&format!( - "\n--- Turn {} ---\n{}", - turn_count + 1, - response_text - )); - conversation_result.push_str(&format!( - "\n[Task completed after {} turns]", - turn_count + 1 - )); - } - Err(e) => { - conversation_result.push_str(&format!("\n[Error after {} turns: {}]", turn_count, e)); + Ok(vec![Content::text(response_text)]) } - } + Err(e) => Err(ToolError::ExecutionError(format!( + "Subagent execution failed: {}", + e + ))), + }; // Clean up the subagent handle if let Err(e) = handle.await { tracing::debug!("Subagent handle cleanup error: {}", e); } - // Return the complete conversation result - Ok(vec![Content::text(format!( - "Subagent task completed:\n{}", - conversation_result - ))]) + // Return the result + result } diff --git a/crates/goose/src/agents/task.rs b/crates/goose/src/agents/subagent_task_config.rs similarity index 98% rename from crates/goose/src/agents/task.rs rename to crates/goose/src/agents/subagent_task_config.rs index d6cae9bfd945..261fb82b6f5f 100644 --- a/crates/goose/src/agents/task.rs +++ b/crates/goose/src/agents/subagent_task_config.rs @@ -1,7 +1,6 @@ use crate::agents::extension_manager::ExtensionManager; use crate::providers::base::Provider; use mcp_core::protocol::JsonRpcMessage; -use mcp_core::tool::Tool; use std::fmt; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; diff --git a/crates/goose/src/agents/subagent_tools.rs b/crates/goose/src/agents/subagent_tools.rs deleted file mode 100644 index f8f35f1ff0ba..000000000000 --- a/crates/goose/src/agents/subagent_tools.rs +++ /dev/null @@ -1,68 +0,0 @@ -use indoc::indoc; -use mcp_core::tool::{Tool, ToolAnnotations}; -use serde_json::json; - -pub const SUBAGENT_RUN_TASK_TOOL_NAME: &str = "subagent__run_task"; - -pub fn run_task_subagent_tool() -> Tool { - Tool::new( - SUBAGENT_RUN_TASK_TOOL_NAME.to_string(), - indoc! {r#" - Spawn a specialized subagent to handle a specific task completely and automatically. - - This tool creates a subagent, processes your task through a complete conversation, - and returns the final result. The subagent is automatically cleaned up after completion. - - You can configure the subagent in two ways: - 1. Using a recipe file that defines instructions, extensions, and behavior - 2. Providing direct instructions for ad-hoc tasks - - The subagent will work autonomously until the task is complete, it reaches max_turns, - or it encounters an error. You'll get the final result without needing to manage - the subagent lifecycle manually. - - Examples: - - "Convert these unittest files to pytest format: file1.py, file2.py" - - "Research the latest developments in AI and provide a comprehensive summary" - - "Review this code for security vulnerabilities and suggest fixes" - - "Refactor this legacy code to use modern Python patterns" - "#} - .to_string(), - json!({ - "type": "object", - "required": ["task"], - "properties": { - "recipe_name": { - "type": "string", - "description": "Name of the recipe file to configure the subagent (e.g., 'research_assistant_recipe.yaml'). Either this or 'instructions' must be provided." - }, - "instructions": { - "type": "string", - "description": "Direct instructions for the subagent's task. Either this or 'recipe_name' must be provided. Example: 'You are a code refactoring assistant. Help convert unittest tests to pytest format.'" - }, - "task": { - "type": "string", - "description": "The task description or initial message for the subagent to work on" - }, - "max_turns": { - "type": "integer", - "description": "Maximum number of conversation turns before auto-completion (default: 10)", - "minimum": 1, - "default": 10 - }, - "timeout_seconds": { - "type": "integer", - "description": "Optional timeout for the entire task in seconds", - "minimum": 1 - } - } - }), - Some(ToolAnnotations { - title: Some("Run subagent task".to_string()), - read_only_hint: false, - destructive_hint: false, - idempotent_hint: false, - open_world_hint: false, - }), - ) -} diff --git a/crates/goose/src/agents/subagent_types.rs b/crates/goose/src/agents/subagent_types.rs deleted file mode 100644 index 1fbc85563f74..000000000000 --- a/crates/goose/src/agents/subagent_types.rs +++ /dev/null @@ -1,42 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SpawnSubAgentArgs { - pub recipe_name: Option, - pub instructions: Option, - pub message: String, - pub max_turns: Option, - pub timeout_seconds: Option, -} - -impl SpawnSubAgentArgs { - pub fn new_with_recipe(recipe_name: String, message: String) -> Self { - Self { - recipe_name: Some(recipe_name), - instructions: None, - message, - max_turns: None, - timeout_seconds: None, - } - } - - pub fn new_with_instructions(instructions: String, message: String) -> Self { - Self { - recipe_name: None, - instructions: Some(instructions), - message, - max_turns: None, - timeout_seconds: None, - } - } - - pub fn with_max_turns(mut self, max_turns: usize) -> Self { - self.max_turns = Some(max_turns); - self - } - - pub fn with_timeout(mut self, timeout_seconds: u64) -> Self { - self.timeout_seconds = Some(timeout_seconds); - self - } -} From 545ea23251c34288d867d1f0ed1cb7775387db7b Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Wed, 16 Jul 2025 11:27:50 -0700 Subject: [PATCH 35/43] merge --- crates/goose/src/agents/agent.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 499ea7b7d03e..37cb915c84ff 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -11,6 +11,7 @@ use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; use crate::agents::recipe_tools::dynamic_task_tools::{ create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, +}; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, From ae0933e1ef7937363e3239e5cb31bdfeac9ff3b3 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Wed, 16 Jul 2025 13:40:40 -0700 Subject: [PATCH 36/43] stdout --- crates/goose-cli/src/session/mod.rs | 19 +++++++++------- crates/goose-cli/src/session/output.rs | 22 +++++++++++++++++-- .../task_execution_tracker.rs | 12 ++++++++++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 8c25c0b3f238..51284c4f97b6 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -7,7 +7,10 @@ mod prompt; mod task_execution_display; mod thinking; -use crate::session::task_execution_display::TASK_EXECUTION_NOTIFICATION_TYPE; +use crate::session::task_execution_display::{ + format_task_execution_notification, TASK_EXECUTION_NOTIFICATION_TYPE, +}; +use std::io::Write; pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; @@ -1072,13 +1075,13 @@ impl Session { // Handle subagent notifications - show immediately if let Some(_id) = subagent_id { // TODO: proper display for subagent notifications - // if interactive { - // let _ = progress_bars.hide(); - // println!("{}", console::style(&formatted_message).green().dim()); - // } else { - // progress_bars.log(&formatted_message); - // } - continue; + if interactive { + let _ = progress_bars.hide(); + println!("{}", console::style(&formatted_message).green().dim()); + } else { + progress_bars.log(&formatted_message); + } + // continue; } else if let Some(ref notification_type) = _notification_type { if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { if interactive { diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 06435bd32be3..b3305821cb5e 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -463,8 +463,26 @@ fn print_params(value: &Value, depth: usize, debug: bool) { } } Value::String(s) => { - if !debug && s.len() > get_tool_params_max_length() { - println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); + // Special handling for text_instruction to show more content + let max_length = if key == "text_instruction" { + 200 // Allow longer display for text instructions + } else { + get_tool_params_max_length() + }; + + if !debug && s.len() > max_length { + // For text instructions, show a preview instead of just "..." + if key == "text_instruction" { + let preview = &s[..max_length.saturating_sub(3)]; + println!( + "{}{}: {}", + indent, + style(key).dim(), + style(format!("{}...", preview)).green() + ); + } else { + println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); + } } else { println!("{}{}: {}", indent, style(key).dim(), style(s).green()); } diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index 7d3cf347f28e..957ef2983d50 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -39,6 +39,18 @@ fn format_task_metadata(task_info: &TaskInfo) -> String { }) .collect::>() .join(",") + } else if task_info.task.task_type == "text_instruction" { + // For text_instruction tasks, extract and display the instruction + if let Some(text_instruction) = task_info.task.get_text_instruction() { + // Truncate long instructions to keep the display clean + if text_instruction.len() > 80 { + format!("instruction={}...", &text_instruction[..77]) + } else { + format!("instruction={}", text_instruction) + } + } else { + String::new() + } } else { String::new() } From 8d7e402341a8b8446c3a9ce484533bba21cc50ed Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Wed, 16 Jul 2025 14:46:13 -0700 Subject: [PATCH 37/43] rm probablistic test --- .../goose/tests/pricing_integration_test.rs | 74 ++----------------- 1 file changed, 8 insertions(+), 66 deletions(-) diff --git a/crates/goose/tests/pricing_integration_test.rs b/crates/goose/tests/pricing_integration_test.rs index 9e4472905f6b..6cc719edc704 100644 --- a/crates/goose/tests/pricing_integration_test.rs +++ b/crates/goose/tests/pricing_integration_test.rs @@ -1,74 +1,12 @@ use goose::providers::pricing::{get_model_pricing, initialize_pricing_cache, refresh_pricing}; use std::time::Instant; -#[tokio::test] -async fn test_pricing_cache_performance() { - // Initialize the cache - let start = Instant::now(); - initialize_pricing_cache() - .await - .expect("Failed to initialize pricing cache"); - let init_duration = start.elapsed(); - println!("Cache initialization took: {:?}", init_duration); - - // Test fetching pricing for common models (using actual model names from OpenRouter) - let models = vec![ - ("anthropic", "claude-3.5-sonnet"), - ("openai", "gpt-4o"), - ("openai", "gpt-4o-mini"), - ("google", "gemini-flash-1.5"), - ("anthropic", "claude-sonnet-4"), - ]; - - // First fetch (should hit cache) - let start = Instant::now(); - for (provider, model) in &models { - let pricing = get_model_pricing(provider, model).await; - assert!( - pricing.is_some(), - "Expected pricing for {}/{}", - provider, - model - ); - } - let first_fetch_duration = start.elapsed(); - println!( - "First fetch of {} models took: {:?}", - models.len(), - first_fetch_duration - ); - - // Second fetch (definitely from cache) - let start = Instant::now(); - for (provider, model) in &models { - let pricing = get_model_pricing(provider, model).await; - assert!( - pricing.is_some(), - "Expected pricing for {}/{}", - provider, - model - ); - } - let second_fetch_duration = start.elapsed(); - println!( - "Second fetch of {} models took: {:?}", - models.len(), - second_fetch_duration - ); - - // Cache fetch should be significantly faster - // Note: Both fetches are already very fast (microseconds), so we just ensure - // the second fetch is not slower than the first (allowing for some variance) - assert!( - second_fetch_duration <= first_fetch_duration * 2, - "Cache fetch should not be significantly slower than initial fetch. First: {:?}, Second: {:?}", - first_fetch_duration, - second_fetch_duration - ); -} - #[tokio::test] async fn test_pricing_refresh() { + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_refresh_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + // Initialize first initialize_pricing_cache() .await @@ -90,6 +28,10 @@ async fn test_pricing_refresh() { refreshed_pricing.is_some(), "Expected pricing after refresh" ); + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } #[tokio::test] From 564369a454e61f36e1bb29665ae7b11d345c4ded Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Wed, 16 Jul 2025 14:50:37 -0700 Subject: [PATCH 38/43] fmt --- crates/goose/tests/pricing_integration_test.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/goose/tests/pricing_integration_test.rs b/crates/goose/tests/pricing_integration_test.rs index 6cc719edc704..655d7701625f 100644 --- a/crates/goose/tests/pricing_integration_test.rs +++ b/crates/goose/tests/pricing_integration_test.rs @@ -6,7 +6,7 @@ async fn test_pricing_refresh() { // Use a unique cache directory for this test to avoid conflicts let test_cache_dir = format!("/tmp/goose_test_cache_refresh_{}", std::process::id()); std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); - + // Initialize first initialize_pricing_cache() .await @@ -28,7 +28,7 @@ async fn test_pricing_refresh() { refreshed_pricing.is_some(), "Expected pricing after refresh" ); - + // Clean up std::env::remove_var("GOOSE_CACHE_DIR"); let _ = std::fs::remove_dir_all(&test_cache_dir); From b28ce9c9a7ca3b890d2f14ffd6aee16f86809e3e Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Wed, 16 Jul 2025 16:31:24 -0700 Subject: [PATCH 39/43] rm test --- .../goose/tests/pricing_integration_test.rs | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/crates/goose/tests/pricing_integration_test.rs b/crates/goose/tests/pricing_integration_test.rs index 655d7701625f..15a77abb0002 100644 --- a/crates/goose/tests/pricing_integration_test.rs +++ b/crates/goose/tests/pricing_integration_test.rs @@ -1,6 +1,80 @@ use goose::providers::pricing::{get_model_pricing, initialize_pricing_cache, refresh_pricing}; use std::time::Instant; +#[tokio::test] +async fn test_pricing_cache_performance() { + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_perf_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + + // Initialize the cache + let start = Instant::now(); + initialize_pricing_cache() + .await + .expect("Failed to initialize pricing cache"); + let init_duration = start.elapsed(); + println!("Cache initialization took: {:?}", init_duration); + + // Test fetching pricing for common models (using actual model names from OpenRouter) + let models = vec![ + ("anthropic", "claude-3.5-sonnet"), + ("openai", "gpt-4o"), + ("openai", "gpt-4o-mini"), + ("google", "gemini-flash-1.5"), + ("anthropic", "claude-sonnet-4"), + ]; + + // First fetch (should hit cache) + let start = Instant::now(); + for (provider, model) in &models { + let pricing = get_model_pricing(provider, model).await; + assert!( + pricing.is_some(), + "Expected pricing for {}/{}", + provider, + model + ); + } + let first_fetch_duration = start.elapsed(); + println!( + "First fetch of {} models took: {:?}", + models.len(), + first_fetch_duration + ); + + // Second fetch (definitely from cache) + let start = Instant::now(); + for (provider, model) in &models { + let pricing = get_model_pricing(provider, model).await; + assert!( + pricing.is_some(), + "Expected pricing for {}/{}", + provider, + model + ); + } + let second_fetch_duration = start.elapsed(); + println!( + "Second fetch of {} models took: {:?}", + models.len(), + second_fetch_duration + ); + + // Cache fetch should be significantly faster + // Note: Both fetches are already very fast (microseconds), so we just ensure + // the second fetch is not slower than the first (allowing for some variance) + assert!( + second_fetch_duration <= first_fetch_duration * 2, + "Cache fetch should not be significantly slower than initial fetch. First: {:?}, Second: {:?}", + first_fetch_duration, + second_fetch_duration + ); + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); +} + #[tokio::test] async fn test_pricing_refresh() { // Use a unique cache directory for this test to avoid conflicts @@ -36,6 +110,10 @@ async fn test_pricing_refresh() { #[tokio::test] async fn test_model_not_in_openrouter() { + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_model_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + initialize_pricing_cache() .await .expect("Failed to initialize pricing cache"); @@ -46,12 +124,20 @@ async fn test_model_not_in_openrouter() { pricing.is_none(), "Should return None for non-existent model" ); + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } #[tokio::test] async fn test_concurrent_access() { use tokio::task; + // Use a unique cache directory for this test to avoid conflicts + let test_cache_dir = format!("/tmp/goose_test_cache_concurrent_{}", std::process::id()); + std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); + initialize_pricing_cache() .await .expect("Failed to initialize pricing cache"); @@ -75,4 +161,8 @@ async fn test_concurrent_access() { assert!(has_pricing, "Task {} should have gotten pricing", task_id); println!("Task {} took: {:?}", task_id, duration); } + + // Clean up + std::env::remove_var("GOOSE_CACHE_DIR"); + let _ = std::fs::remove_dir_all(&test_cache_dir); } From 11f33ffd84d2aba2f480eb92a40ce6f97892311f Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 17 Jul 2025 16:13:11 +1000 Subject: [PATCH 40/43] fixed merge conflicts --- .../goose-cli/src/recipes/extract_from_cli.rs | 1 - crates/goose-cli/src/session/mod.rs | 2 - .../src/session/task_execution_display/mod.rs | 4 +- crates/goose/src/agents/agent.rs | 4 +- .../agents/recipe_tools/dynamic_task_tools.rs | 15 +++---- .../agents/recipe_tools/sub_recipe_tools.rs | 4 +- .../recipe_tools/sub_recipe_tools/tests.rs | 1 - crates/goose/src/agents/sub_recipe_manager.rs | 2 +- .../agents/subagent_execution_tool/lib/mod.rs | 28 +++++++++--- .../src/agents/subagent_execution_tool/mod.rs | 1 + .../subagent_execute_task_tool.rs | 17 +++----- .../subagent_execution_tool/task_types.rs | 1 - .../agents/subagent_execution_tool/tasks.rs | 43 +++---------------- .../subagent_execution_tool/tasks_manager.rs | 2 +- crates/goose/src/recipe/mod.rs | 21 --------- 15 files changed, 50 insertions(+), 96 deletions(-) diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 5f4c1b2d9bb6..56113b75dacd 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -31,7 +31,6 @@ pub fn extract_recipe_info_from_cli( let additional_sub_recipe = SubRecipe { path: recipe_file_path.to_string_lossy().to_string(), name, - timeout_in_seconds: None, values: None, sequential_when_repeated: true, }; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index e459c6c5100a..a06165356590 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -23,8 +23,6 @@ use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; use goose::utils::safe_truncate; -use std::io::Write; -use task_execution_display::format_task_execution_notification; use anyhow::{Context, Result}; use completion::GooseCompleter; diff --git a/crates/goose-cli/src/session/task_execution_display/mod.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs index ec6c41ff201b..b0b208ed546e 100644 --- a/crates/goose-cli/src/session/task_execution_display/mod.rs +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -1,5 +1,5 @@ -use goose::agents::sub_recipe_execution_tool::lib::TaskStatus; -use goose::agents::sub_recipe_execution_tool::notification_events::{ +use goose::agents::subagent_execution_tool::lib::TaskStatus; +use goose::agents::subagent_execution_tool::notification_events::{ TaskExecutionNotificationEvent, TaskInfo, }; use serde_json::Value; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 50bf5cdf6c84..ce4471fb3540 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -12,11 +12,11 @@ use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_ use crate::agents::recipe_tools::dynamic_task_tools::{ create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, }; -use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, }; +use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::{push_message, Message}; use crate::permission::permission_judge::check_tool_permissions; @@ -308,7 +308,7 @@ impl Agent { TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); subagent_execute_task_tool::run_tasks(tool_call.arguments.clone(), task_config, &self.tasks_manager,).await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { - create_dynamic_task(tool_call.arguments.clone()).await + create_dynamic_task(tool_call.arguments.clone(), &self.tasks_manager).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index 6bbafe872815..70358e496365 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -6,6 +6,7 @@ use crate::agents::recipe_tools::sub_recipe_tools::{ EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, }; use crate::agents::subagent_execution_tool::task_types::Task; +use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; use crate::agents::tool_execution::ToolCallResult; use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::{json, Value}; @@ -96,11 +97,6 @@ fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec .unwrap_or("") .to_string(); - let timeout_seconds = task_param - .get("timeout_seconds") - .and_then(|v| v.as_u64()) - .unwrap_or(300); - let payload = json!({ "text_instruction": text_instruction }); @@ -108,7 +104,6 @@ fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec Task { id: uuid::Uuid::new_v4().to_string(), task_type: "text_instruction".to_string(), - timeout_in_seconds: Some(timeout_seconds), payload, } }) @@ -116,13 +111,14 @@ fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec } fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { + let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); json!({ - "tasks": tasks, + "task_ids": task_ids, "execution_mode": execution_mode }) } -pub async fn create_dynamic_task(params: Value) -> ToolCallResult { +pub async fn create_dynamic_task(params: Value, tasks_manager: &TasksManager) -> ToolCallResult { let task_params_array = extract_task_parameters(¶ms); if task_params_array.is_empty() { @@ -140,7 +136,7 @@ pub async fn create_dynamic_task(params: Value) -> ToolCallResult { EXECUTION_MODE_SEQUENTIAL }; - let task_execution_payload = create_task_execution_payload(tasks, execution_mode); + let task_execution_payload = create_task_execution_payload(tasks.clone(), execution_mode); let tasks_json = match serde_json::to_string(&task_execution_payload) { Ok(json) => json, @@ -151,5 +147,6 @@ pub async fn create_dynamic_task(params: Value) -> ToolCallResult { )))) } }; + tasks_manager.save_tasks(tasks.clone()).await; ToolCallResult::from(Ok(vec![Content::text(tasks_json)])) } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index b668aaf07b08..10b39d5c38cd 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -5,8 +5,8 @@ use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; -use crate::agents::sub_recipe_execution_tool::lib::{ExecutionMode, Task}; -use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; +use crate::agents::subagent_execution_tool::lib::{ExecutionMode, Task}; +use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; use super::param_utils::prepare_command_params; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index efad2ce2574c..48c7a957e57d 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -11,7 +11,6 @@ mod tests { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), - timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), sequential_when_repeated: true, }; diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 33229b97ecef..891c3c9b0e2b 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -7,7 +7,7 @@ use crate::{ recipe_tools::sub_recipe_tools::{ create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, - sub_recipe_execution_tool::tasks_manager::TasksManager, + subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }, recipe::SubRecipe, diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index aba50c087fa9..0d186e59c36c 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -1,6 +1,6 @@ -use crate::agents::subagent_execution_tool::executor::{ +use crate::agents::subagent_execution_tool::{executor::{ execute_single_task, execute_tasks_in_parallel, -}; +}, tasks_manager::TasksManager}; pub use crate::agents::subagent_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -14,10 +14,28 @@ pub async fn execute_tasks( execution_mode: ExecutionMode, notifier: mpsc::Sender, task_config: TaskConfig, + tasks_manager: &TasksManager, ) -> Result { - let tasks: Vec = - serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) - .map_err(|e| format!("Failed to parse tasks: {}", e))?; + let task_ids: Vec = serde_json::from_value( + input + .get("task_ids") + .ok_or("Missing task_ids field")? + .clone(), + ) + .map_err(|e| format!("Failed to parse task_ids: {}", e))?; + + let mut tasks = Vec::new(); + for task_id in &task_ids { + match tasks_manager.get_task(task_id).await { + Some(task) => tasks.push(task), + None => { + return Err(format!( + "Task with ID '{}' not found in TasksManager", + task_id + )) + } + } + } let task_count = tasks.len(); match execution_mode { diff --git a/crates/goose/src/agents/subagent_execution_tool/mod.rs b/crates/goose/src/agents/subagent_execution_tool/mod.rs index 9dab0001d031..2226e2d7e58c 100644 --- a/crates/goose/src/agents/subagent_execution_tool/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/mod.rs @@ -5,5 +5,6 @@ pub mod subagent_execute_task_tool; pub mod task_execution_tracker; pub mod task_types; pub mod tasks; +pub mod tasks_manager; pub mod utils; pub mod workers; diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index bc4960d9dec2..f3f1a8bf5cfe 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -3,9 +3,9 @@ use serde_json::Value; use crate::agents::subagent_task_config::TaskConfig; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, - sub_recipe_execution_tool::task_types::ExecutionMode, - sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, + subagent_execution_tool::lib::execute_tasks, + subagent_execution_tool::task_types::ExecutionMode, + subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; @@ -14,7 +14,7 @@ use tokio_stream; pub const SUBAGENT_EXECUTE_TASK_TOOL_NAME: &str = "subagent__execute_task"; pub fn create_subagent_execute_task_tool() -> Tool { Tool::new( - SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, + SUBAGENT_EXECUTE_TASK_TOOL_NAME, "Only use the subagent__execute_task tool when you execute sub recipe task or dynamic task. EXECUTION STRATEGY DECISION: 1. If the tasks are created with execution_mode, use the execution_mode. @@ -29,11 +29,7 @@ User Intent Based: - User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) - User: 'get weather and joke in parallel' → Parallel (1 tool call with task array) - User: 'run these simultaneously' → Parallel (1 tool call with task array) -- User: 'do task A then task B' → Sequential (2 separate tool calls) - -Pre-created Task Based: -- subrecipe__create_task_weather returns execution_mode: 'parallel' → Use parallel execution -- subrecipe__create_task_weather returns execution_mode: 'sequential' → Use sequential execution", +- User: 'do task A then task B' → Sequential (2 separate tool calls)", serde_json::json!({ "type": "object", "properties": { @@ -63,7 +59,7 @@ Pre-created Task Based: ) } -pub async fn run_tasks(execute_data: Value, tasks_manager: &TasksManager) -> ToolCallResult { +pub async fn run_tasks(execute_data: Value, task_config: TaskConfig, tasks_manager: &TasksManager) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); let tasks_manager_clone = tasks_manager.clone(); @@ -78,6 +74,7 @@ pub async fn run_tasks(execute_data: Value, tasks_manager: &TasksManager) -> Too execute_data, execution_mode, notification_tx, + task_config, &tasks_manager_clone, ) .await diff --git a/crates/goose/src/agents/subagent_execution_tool/task_types.rs b/crates/goose/src/agents/subagent_execution_tool/task_types.rs index 92a79318b5fc..270573d6ee58 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_types.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_types.rs @@ -18,7 +18,6 @@ pub enum ExecutionMode { pub struct Task { pub id: String, pub task_type: String, - pub timeout_in_seconds: Option, pub payload: Value, } diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs index a34b8ebfe57b..000dd917baae 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -1,62 +1,31 @@ use serde_json::Value; use std::process::Stdio; use std::sync::Arc; -use std::time::Duration; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use tokio::time::timeout; use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::subagent_execution_tool::task_types::{Task, TaskResult, TaskStatus}; use crate::agents::subagent_handler::run_complete_subagent_task; use crate::agents::subagent_task_config::TaskConfig; -const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; - pub async fn process_task( task: &Task, task_execution_tracker: Arc, task_config: TaskConfig, ) -> TaskResult { - let timeout_in_seconds = task - .timeout_in_seconds - .unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); - let task_clone = task.clone(); - let timeout_duration = Duration::from_secs(timeout_in_seconds); - - let task_execution_tracker_clone = task_execution_tracker.clone(); - match timeout( - timeout_duration, - get_task_result(task_clone, task_execution_tracker, task_config), - ) - .await - { - Ok(Ok(data)) => TaskResult { + match get_task_result(task.clone(), task_execution_tracker, task_config).await { + Ok(data) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, data: Some(data), error: None, }, - Ok(Err(error)) => TaskResult { + Err(error) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Failed, data: None, error: Some(error), - }, - Err(_) => { - let current_output = task_execution_tracker_clone - .get_current_output(&task.id) - .await - .unwrap_or_default(); - - TaskResult { - task_id: task.id.clone(), - status: TaskStatus::Failed, - data: Some(serde_json::json!({ - "partial_output": current_output - })), - error: Some(format!("Task timed out after {}s", timeout_in_seconds)), - } } } } @@ -132,8 +101,7 @@ async fn handle_text_instruction_task( fn build_command(task: &Task) -> Result<(Command, String), String> { let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); - let mut output_identifier = task.id.clone(); - let mut command = if task.task_type == "sub_recipe" { + let (mut command, output_identifier) = if task.task_type == "sub_recipe" { let sub_recipe_name = task .get_sub_recipe_name() .ok_or_else(|| task_error("sub_recipe name"))?; @@ -144,7 +112,6 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { .get_command_parameters() .ok_or_else(|| task_error("command_parameters"))?; - output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); @@ -154,7 +121,7 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { cmd.arg("--params") .arg(format!("{}={}", key_str, value_str)); } - cmd + (cmd, format!("sub-recipe {}", sub_recipe_name)) } else { // This branch should not be reached for text_instruction tasks anymore // as they are handled in handle_text_instruction_task diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs index 433478be1f29..334379fa4ef5 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use crate::agents::sub_recipe_execution_tool::task_types::Task; +use crate::agents::subagent_execution_tool::task_types::Task; #[derive(Debug, Clone)] pub struct TasksManager { diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index b0146dcefda6..37d01ec260eb 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -135,32 +135,11 @@ pub struct Response { pub struct SubRecipe { pub name: String, pub path: String, - pub timeout_in_seconds: Option, - #[serde(default, deserialize_with = "deserialize_value_map_as_string")] - pub values: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub executions: Option, -} -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Execution { - #[serde(default)] - pub parallel: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub runs: Option>, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ExecutionRun { #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, #[serde(default)] pub sequential_when_repeated: bool, } -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Execution { - #[serde(default)] - pub parallel: bool, -} fn deserialize_value_map_as_string<'de, D>( deserializer: D, From e4d822a12d4d9344d3792c998263257daf3534ac Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 17 Jul 2025 16:16:39 +1000 Subject: [PATCH 41/43] fixed fmt and clippy --- crates/goose/src/agents/agent.rs | 7 ++++++- .../src/agents/recipe_tools/dynamic_task_tools.rs | 10 ++++------ .../src/agents/recipe_tools/sub_recipe_tools/tests.rs | 4 ++-- .../src/agents/subagent_execution_tool/lib/mod.rs | 7 ++++--- .../subagent_execute_task_tool.rs | 6 +++++- .../goose/src/agents/subagent_execution_tool/tasks.rs | 2 +- crates/goose/tests/pricing_integration_test.rs | 10 +++++----- 7 files changed, 27 insertions(+), 19 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index ce4471fb3540..2cf0370865e1 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -306,7 +306,12 @@ impl Agent { let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); - subagent_execute_task_tool::run_tasks(tool_call.arguments.clone(), task_config, &self.tasks_manager,).await + subagent_execute_task_tool::run_tasks( + tool_call.arguments.clone(), + task_config, + &self.tasks_manager, + ) + .await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { create_dynamic_task(tool_call.arguments.clone(), &self.tasks_manager).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index 70358e496365..8b436ed4bba8 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -15,9 +15,8 @@ pub const DYNAMIC_TASK_TOOL_NAME_PREFIX: &str = "dynamic_task__create_task"; pub fn create_dynamic_task_tool() -> Tool { Tool::new( - format!("{}", DYNAMIC_TASK_TOOL_NAME_PREFIX), - format!( - "Use this tool to create one or more dynamic tasks from a shared text instruction and varying parameters.\ + DYNAMIC_TASK_TOOL_NAME_PREFIX.to_string(), + "Use this tool to create one or more dynamic tasks from a shared text instruction and varying parameters.\ How it works: - Provide a single text instruction - Use the 'task_parameters' field to pass an array of parameter sets @@ -40,8 +39,7 @@ pub fn create_dynamic_task_tool() -> Tool { timeout_seconds: 300 text_instruction: Get weather for San Francisco. timeout_seconds: 300 - " - ), + ".to_string(), json!({ "type": "object", "properties": { @@ -70,7 +68,7 @@ pub fn create_dynamic_task_tool() -> Tool { } }), Some(ToolAnnotations { - title: Some(format!("Dynamic Task Creation")), + title: Some("Dynamic Task Creation".to_string()), read_only_hint: false, destructive_hint: true, idempotent_hint: false, diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 48c7a957e57d..0b682b0b649b 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -71,7 +71,7 @@ mod tests { prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - let result = get_input_schema(&sub_recipe).unwrap(); + let result = get_input_schema(&sub_recipe).unwrap(); verify_task_parameters( result, @@ -113,7 +113,7 @@ mod tests { prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); sub_recipe.values = None; - let result = get_input_schema(&sub_recipe).unwrap(); + let result = get_input_schema(&sub_recipe).unwrap(); verify_task_parameters( result, diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 0d186e59c36c..4cafb64ee712 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -1,9 +1,10 @@ -use crate::agents::subagent_execution_tool::{executor::{ - execute_single_task, execute_tasks_in_parallel, -}, tasks_manager::TasksManager}; pub use crate::agents::subagent_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; +use crate::agents::subagent_execution_tool::{ + executor::{execute_single_task, execute_tasks_in_parallel}, + tasks_manager::TasksManager, +}; use crate::agents::subagent_task_config::TaskConfig; use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index f3f1a8bf5cfe..bc5063317d35 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -59,7 +59,11 @@ User Intent Based: ) } -pub async fn run_tasks(execute_data: Value, task_config: TaskConfig, tasks_manager: &TasksManager) -> ToolCallResult { +pub async fn run_tasks( + execute_data: Value, + task_config: TaskConfig, + tasks_manager: &TasksManager, +) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); let tasks_manager_clone = tasks_manager.clone(); diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs index 000dd917baae..7ed93245d026 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -26,7 +26,7 @@ pub async fn process_task( status: TaskStatus::Failed, data: None, error: Some(error), - } + }, } } diff --git a/crates/goose/tests/pricing_integration_test.rs b/crates/goose/tests/pricing_integration_test.rs index 15a77abb0002..083f96daf74d 100644 --- a/crates/goose/tests/pricing_integration_test.rs +++ b/crates/goose/tests/pricing_integration_test.rs @@ -6,7 +6,7 @@ async fn test_pricing_cache_performance() { // Use a unique cache directory for this test to avoid conflicts let test_cache_dir = format!("/tmp/goose_test_cache_perf_{}", std::process::id()); std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); - + // Initialize the cache let start = Instant::now(); initialize_pricing_cache() @@ -69,7 +69,7 @@ async fn test_pricing_cache_performance() { first_fetch_duration, second_fetch_duration ); - + // Clean up std::env::remove_var("GOOSE_CACHE_DIR"); let _ = std::fs::remove_dir_all(&test_cache_dir); @@ -113,7 +113,7 @@ async fn test_model_not_in_openrouter() { // Use a unique cache directory for this test to avoid conflicts let test_cache_dir = format!("/tmp/goose_test_cache_model_{}", std::process::id()); std::env::set_var("GOOSE_CACHE_DIR", &test_cache_dir); - + initialize_pricing_cache() .await .expect("Failed to initialize pricing cache"); @@ -124,7 +124,7 @@ async fn test_model_not_in_openrouter() { pricing.is_none(), "Should return None for non-existent model" ); - + // Clean up std::env::remove_var("GOOSE_CACHE_DIR"); let _ = std::fs::remove_dir_all(&test_cache_dir); @@ -161,7 +161,7 @@ async fn test_concurrent_access() { assert!(has_pricing, "Task {} should have gotten pricing", task_id); println!("Task {} took: {:?}", task_id, duration); } - + // Clean up std::env::remove_var("GOOSE_CACHE_DIR"); let _ = std::fs::remove_dir_all(&test_cache_dir); From aa25ef4fa8e7ec0eed4cc4f393d4ad871939cf6b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 17 Jul 2025 16:56:59 +1000 Subject: [PATCH 42/43] more fix on merge conflicts --- crates/goose-cli/src/session/builder.rs | 1 - .../agents/recipe_tools/dynamic_task_tools.rs | 11 ++++----- .../agents/recipe_tools/sub_recipe_tools.rs | 2 -- .../agents/subagent_execution_tool/lib/mod.rs | 18 +++++++++++---- .../subagent_execute_task_tool.rs | 4 ++-- .../task_execution_tracker.rs | 23 ++++++++++++++++++- .../subagent_execution_tool/task_types.rs | 6 +++++ .../subagent_execution_tool/utils/mod.rs | 3 +++ .../subagent_execution_tool/utils/tests.rs | 9 ++------ 9 files changed, 53 insertions(+), 24 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 5e780d661c13..588325c37368 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -204,7 +204,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); - // Sub-recipes if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index 8b436ed4bba8..449ea04d2a9c 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -2,11 +2,8 @@ // Module: Dynamic Task Tools // Handles creation of tasks dynamically without sub-recipes // ======================================= -use crate::agents::recipe_tools::sub_recipe_tools::{ - EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, -}; -use crate::agents::subagent_execution_tool::task_types::Task; use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; +use crate::agents::subagent_execution_tool::{lib::ExecutionMode, task_types::Task}; use crate::agents::tool_execution::ToolCallResult; use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::{json, Value}; @@ -108,7 +105,7 @@ fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec .collect() } -fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { +fn create_task_execution_payload(tasks: Vec, execution_mode: ExecutionMode) -> Value { let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); json!({ "task_ids": task_ids, @@ -129,9 +126,9 @@ pub async fn create_dynamic_task(params: Value, tasks_manager: &TasksManager) -> // Use parallel execution if there are multiple tasks, sequential for single task let execution_mode = if tasks.len() > 1 { - EXECUTION_MODE_PARALLEL + ExecutionMode::Parallel } else { - EXECUTION_MODE_SEQUENTIAL + ExecutionMode::Sequential }; let task_execution_payload = create_task_execution_payload(tasks.clone(), execution_mode); diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 10b39d5c38cd..66b89ea39db9 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -12,8 +12,6 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -pub const EXECUTION_MODE_PARALLEL: &str = "parallel"; -pub const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 4cafb64ee712..8051ffa2c3f3 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -7,7 +7,7 @@ use crate::agents::subagent_execution_tool::{ }; use crate::agents::subagent_task_config::TaskConfig; use mcp_core::protocol::JsonRpcMessage; -use serde_json::Value; +use serde_json::{json, Value}; use tokio::sync::mpsc; pub async fn execute_tasks( @@ -49,9 +49,19 @@ pub async fn execute_tasks( } } ExecutionMode::Parallel => { - let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, notifier, task_config).await; - handle_response(response) + if tasks.iter().any(|task| task.get_sequential_when_repeated()) { + Ok(json!( + { + "execution_mode": ExecutionMode::Sequential, + "task_ids": task_ids, + "results": ["the tasks should be executed sequentially, no matter how user requests it. Please use the subrecipe__execute_task tool to execute the tasks sequentially."] + } + )) + } else { + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, notifier.clone()).await; + handle_response(response) + } } } } diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index bc5063317d35..73e3fa124a8e 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -27,7 +27,7 @@ IMPLEMENTATION: EXAMPLES: User Intent Based: - User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) -- User: 'get weather and joke in parallel' → Parallel (1 tool call with task array) +- User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) - User: 'run these simultaneously' → Parallel (1 tool call with task array) - User: 'do task A then task B' → Sequential (2 separate tool calls)", serde_json::json!({ @@ -37,7 +37,7 @@ User Intent Based: "type": "string", "enum": ["sequential", "parallel"], "default": "sequential", - "description": "Execution strategy for multiple tasks. For pre-created tasks, respect the execution_mode from task creation. For user intent, use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." + "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." }, "task_ids": { "type": "array", diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index 957ef2983d50..c720459e01ae 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -123,12 +123,33 @@ impl TaskExecutionTracker { .map(|task_info| task_info.current_output.clone()) } + async fn format_line(&self, task_info: Option<&TaskInfo>, line: &str) -> String { + if let Some(task_info) = task_info { + let task_name = get_task_name(task_info); + let task_type = task_info.task.task_type.clone(); + let metadata = format_task_metadata(task_info); + + if metadata.is_empty() { + format!("[{} ({})] {}", task_name, task_type, line) + } else { + format!("[{} ({}) {}] {}", task_name, task_type, metadata, line) + } + } else { + line.to_string() + } + } + pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { + let tasks = self.tasks.read().await; + let task_info = tasks.get(task_id); + + let formatted_line = self.format_line(task_info, line).await; + drop(tasks); let event = TaskExecutionNotificationEvent::line_output( task_id.to_string(), - line.to_string(), + formatted_line, ); if let Err(e) = diff --git a/crates/goose/src/agents/subagent_execution_tool/task_types.rs b/crates/goose/src/agents/subagent_execution_tool/task_types.rs index 270573d6ee58..796491f624f2 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_types.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_types.rs @@ -34,6 +34,12 @@ impl Task { .and_then(|cp| cp.as_object()) } + pub fn get_sequential_when_repeated(&self) -> bool { + self.get_sub_recipe() + .and_then(|sr| sr.get("sequential_when_repeated").and_then(|v| v.as_bool())) + .unwrap_or_default() + } + pub fn get_sub_recipe_name(&self) -> Option<&str> { self.get_sub_recipe() .and_then(|sr| sr.get("name")) diff --git a/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs b/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs index 2f6791ff8278..5d75675283d3 100644 --- a/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/utils/mod.rs @@ -22,3 +22,6 @@ pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usiz ); (total, pending, running, completed, failed) } + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs b/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs index f799b699aaca..b4e7f757b420 100644 --- a/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/subagent_execution_tool/utils/tests.rs @@ -1,5 +1,5 @@ -use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskInfo, TaskStatus}; -use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; +use crate::agents::subagent_execution_tool::task_types::{Task, TaskInfo, TaskStatus}; +use crate::agents::subagent_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::json; use std::collections::HashMap; @@ -22,7 +22,6 @@ mod test_get_task_name { let sub_recipe_task = Task { id: "task_1".to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: None, payload: json!({ "sub_recipe": { "name": "my_recipe", @@ -40,7 +39,6 @@ mod test_get_task_name { fn falls_back_to_task_id_for_text_instruction() { let text_task = Task { id: "task_2".to_string(), - timeout_in_seconds: None, task_type: "text_instruction".to_string(), payload: json!({"text_instruction": "do something"}), }; @@ -55,7 +53,6 @@ mod test_get_task_name { let malformed_task = Task { id: "task_3".to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: None, payload: json!({ "sub_recipe": { "recipe_path": "/path/to/recipe" @@ -74,7 +71,6 @@ mod test_get_task_name { let malformed_task = Task { id: "task_4".to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: None, payload: json!({}), // missing "sub_recipe" field }; @@ -91,7 +87,6 @@ mod count_by_status { let task = Task { id: id.to_string(), task_type: "test".to_string(), - timeout_in_seconds: None, payload: json!({}), }; create_task_info_with_defaults(task, status) From 161c80935b3e9f8a526c7c85de1aa9f4b5a9cd86 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 17 Jul 2025 18:06:36 +1000 Subject: [PATCH 43/43] fix test compilation --- crates/goose-cli/src/session/task_execution_display/tests.rs | 2 +- crates/goose/src/agents/subagent_execution_tool/lib/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display/tests.rs b/crates/goose-cli/src/session/task_execution_display/tests.rs index fb53285080d3..725d161dff5b 100644 --- a/crates/goose-cli/src/session/task_execution_display/tests.rs +++ b/crates/goose-cli/src/session/task_execution_display/tests.rs @@ -1,5 +1,5 @@ use super::*; -use goose::agents::sub_recipe_execution_tool::notification_events::{ +use goose::agents::subagent_execution_tool::notification_events::{ FailedTaskInfo, TaskCompletionStats, TaskExecutionStats, }; use serde_json::json; diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 8051ffa2c3f3..81d728886eab 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -59,7 +59,7 @@ pub async fn execute_tasks( )) } else { let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, notifier.clone()).await; + execute_tasks_in_parallel(tasks, notifier.clone(), task_config).await; handle_response(response) } }