From 75d3cd1a8aaced66754156ca58cb1b451f7a0d48 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 7 Jul 2025 13:10:24 +1000 Subject: [PATCH 01/34] initial version of run sub recipe multiple times --- .../goose-cli/src/recipes/extract_from_cli.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 187 ++++++++++++++++++ .../recipe_tools/sub_recipe_tools/tests.rs | 74 +++++++ .../sub_recipe_execute_task_tool.rs | 22 ++- crates/goose/src/agents/sub_recipe_manager.rs | 14 +- crates/goose/src/recipe/mod.rs | 15 ++ 6 files changed, 298 insertions(+), 15 deletions(-) diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 0550ee26fef4..84d578c0cae4 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -32,6 +32,7 @@ pub fn extract_recipe_info_from_cli( path: recipe_file_path.to_string_lossy().to_string(), name, values: None, + executions: None, }; all_sub_recipes.push(additional_sub_recipe); } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 928cf8bd0845..7ef2e4737bc6 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::{collections::HashMap, fs}; use anyhow::Result; @@ -9,6 +10,7 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; +#[allow(dead_code)] pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); Tool::new( @@ -25,6 +27,32 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } +#[allow(dead_code)] +pub fn create_multiple_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { + let input_schema = get_input_schema_for_multiple_sub_recipe_task(sub_recipe).unwrap(); + Tool::new( + format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), + format!( + "Create one or more tasks to run the '{}' sub recipe. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same sub recipe but with different parameter values. \ + This is useful when you need to execute the same sub recipe multiple times with varying inputs. \ + After creating the task list, pass it to the task executor to run all tasks.", + sub_recipe.name + ), + input_schema, + Some(ToolAnnotations { + title: Some(format!("create multiple sub recipe tasks for {}", sub_recipe.name)), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + fn get_sub_recipe_parameter_definition( sub_recipe: &SubRecipe, ) -> Result>> { @@ -34,6 +62,7 @@ fn get_sub_recipe_parameter_definition( Ok(recipe.parameters) } +#[allow(dead_code)] fn get_input_schema(sub_recipe: &SubRecipe) -> Result { let mut sub_recipe_params_map = HashMap::::new(); if let Some(params_with_value) = &sub_recipe.values { @@ -73,6 +102,74 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result { } } +#[allow(dead_code)] +fn get_input_schema_for_multiple_sub_recipe_task(sub_recipe: &SubRecipe) -> Result { + let mut sub_recipe_values = HashSet::::new(); + if let Some(params_with_value) = &sub_recipe.values { + for param_name in params_with_value.keys() { + sub_recipe_values.insert(param_name.clone()); + } + } + + if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { + for run in runs { + if let Some(params_with_value) = &run.values { + for param_name in params_with_value.keys() { + sub_recipe_values.insert(param_name.clone()); + } + } + } + } + + let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; + + let mut param_properties = Map::new(); + let mut param_required = Vec::new(); + + if let Some(parameters) = parameter_definition { + for param in parameters { + if sub_recipe_values.contains(¶m.key.clone()) { + continue; + } + param_properties.insert( + param.key.clone(), + json!({ + "type": param.input_type.to_string(), + "description": param.description.clone(), + }), + ); + if !matches!(param.requirement, RecipeParameterRequirement::Optional) { + param_required.push(param.key); + } + } + } + + // Create the schema using only task_parameters for both single and multiple tasks + let mut properties = Map::new(); + if !param_properties.is_empty() { + properties.insert( + "task_parameters".to_string(), + json!({ + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": param_properties, + "required": param_required + }, + }) + ); + } + Ok(json!({ + "type": "object", + "properties": properties, + })) +} + +#[allow(dead_code)] fn prepare_command_params( sub_recipe: &SubRecipe, params_from_tool_call: Value, @@ -94,6 +191,51 @@ fn prepare_command_params( Ok(sub_recipe_params) } +fn prepare_command_params_for_multiple_sub_recipe_task( + sub_recipe: &SubRecipe, + params_from_tool_call: Vec, +) -> Result>> { + let mut sub_recipe_params = HashMap::::new(); + if let Some(params_with_value) = &sub_recipe.values { + for (param_name, param_value) in params_with_value { + sub_recipe_params.insert(param_name.clone(), param_value.clone()); + } + } + let mut sub_recipe_run_params = Vec::>::new(); + if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { + for run in runs { + let mut sub_recipe_run_param = sub_recipe_params.clone(); + if let Some(params_with_value) = &run.values { + sub_recipe_run_param.extend(params_with_value.clone()); + } + sub_recipe_run_params.push(sub_recipe_run_param); + } + } + println!("===== sub_recipe_run_params: {:?}", sub_recipe_run_params); + println!("===== params_from_tool_call: {:?}", params_from_tool_call); + if params_from_tool_call.is_empty() { + return Ok(sub_recipe_run_params); + } + if sub_recipe_run_params.len() > 0 && sub_recipe_run_params.len() != params_from_tool_call.len() + { + return Err(anyhow::anyhow!( + "The number of runs in the sub recipe does not match the number of task parameters" + )); + } + let mut sub_recipe_run_params = vec![sub_recipe_params.clone(); params_from_tool_call.len()]; + for (index, sub_recipe_task_param) in params_from_tool_call.iter().enumerate() { + if let Some(params_with_value) = sub_recipe_task_param.as_object() { + for (key, value) in params_with_value { + sub_recipe_run_params[index] + .entry(key.to_string()) + .or_insert_with(|| value.as_str().unwrap_or(&value.to_string()).to_string()); + } + } + } + Ok(sub_recipe_run_params) +} + +#[allow(dead_code)] pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { let command_params = prepare_command_params(sub_recipe, params)?; let payload = json!({ @@ -113,5 +255,50 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re Ok(task_json) } +pub async fn create_multiple_sub_recipe_tasks( + sub_recipe: &SubRecipe, + params: Value, +) -> Result { + // Get the task_parameters array + let empty_vec = vec![]; + let task_params_array = params + .get("task_parameters") + .and_then(|v| v.as_array()) + .unwrap_or(&empty_vec); + let command_params = + prepare_command_params_for_multiple_sub_recipe_task(sub_recipe, task_params_array.clone())?; + let tasks = command_params + .iter() + .map(|task_command_param| { + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": task_command_param, + "recipe_path": sub_recipe.path.clone(), + } + }); + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + } + }) + .collect::>(); + let is_parallel = sub_recipe + .executions + .as_ref() + .map(|e| e.parallel) + .unwrap_or(false); + let task_execution_payload = json!({ + "tasks": tasks, + "execution_mode": if is_parallel { "parallel" } else { "sequential" } + }); + println!("===== task_execution_payload: {:?}", task_execution_payload); + + let tasks_json = serde_json::to_string(&task_execution_payload) + .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + Ok(tasks_json) +} + #[cfg(test)] mod tests; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 11ce390a6b3b..4fa8574d5b13 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -9,6 +9,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + executions: None, }; sub_recipe } @@ -42,6 +43,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: None, + executions: None, }; let params: HashMap = HashMap::new(); let params_value = serde_json::to_value(params).unwrap(); @@ -113,6 +115,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: None, + executions: None, }; let sub_recipe_file_content = r#"{ @@ -152,4 +155,75 @@ mod tests { assert_eq!(result["required"][0], "key1"); } } + + mod get_input_schema_for_multiple_sub_recipe_task_tests { + use crate::{ + agents::recipe_tools::sub_recipe_tools::get_input_schema_for_multiple_sub_recipe_task, + recipe::SubRecipe, + }; + + #[test] + fn test_get_input_schema_for_multiple_tasks() { + let sub_recipe_file_content = r#"{ + "version": "1.0.0", + "title": "Test Recipe", + "description": "A test recipe", + "prompt": "Test prompt", + "parameters": [ + { + "key": "param1", + "input_type": "string", + "requirement": "required", + "description": "A required parameter" + }, + { + "key": "param2", + "input_type": "number", + "requirement": "optional", + "description": "An optional parameter" + } + ] + }"#; + + let temp_dir = tempfile::tempdir().unwrap(); + let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); + std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: temp_file.to_string_lossy().to_string(), + values: None, + executions: None, + }; + + let result = get_input_schema_for_multiple_sub_recipe_task(&sub_recipe).unwrap(); + + // Verify the schema structure + assert_eq!(result["type"], "object"); + assert!(result["properties"].is_object()); + + let properties = result["properties"].as_object().unwrap(); + assert_eq!(properties.len(), 1); + assert!(properties.contains_key("task_parameters")); + + let task_params = &properties["task_parameters"]; + assert_eq!(task_params["type"], "array"); + assert_eq!(task_params["minItems"], 1); + + let items = &task_params["items"]; + assert_eq!(items["type"], "object"); + + let item_properties = items["properties"].as_object().unwrap(); + assert_eq!(item_properties.len(), 2); + assert!(item_properties.contains_key("param1")); + assert!(item_properties.contains_key("param2")); + + assert_eq!(item_properties["param1"]["type"], "string"); + assert_eq!(item_properties["param2"]["type"], "number"); + + let required = items["required"].as_array().unwrap(); + assert_eq!(required.len(), 1); + assert_eq!(required[0], "param1"); + } + } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 46738b813b13..0e7e061fd4d1 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -10,19 +10,31 @@ pub fn create_sub_recipe_execute_task_tool() -> Tool { Tool::new( SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, "Only use this tool when you execute sub recipe task. -EXECUTION STRATEGY: -- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution -- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' + +EXECUTION STRATEGY DECISION: +1. PRE-CREATED TASKS: If tasks were created by subrecipe__create_task_* tools, check the execution_mode in the response: + - If execution_mode is 'parallel', use parallel execution + - If execution_mode is 'sequential', use sequential execution + - Always respect the execution_mode from task creation to maintain consistency + +2. USER INTENT: If creating tasks inline or user explicitly specifies: + - DEFAULT: Execute tasks sequentially unless user explicitly requests parallel execution + - PARALLEL: When user uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' IMPLEMENTATION: - Sequential execution: Call this tool multiple times, passing exactly ONE task per call - Parallel execution: Call this tool once, passing an ARRAY of all tasks EXAMPLES: +User Intent Based: - User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) - User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) - User: 'run these simultaneously' → Parallel (1 tool call with task array) -- User: 'do task A then task B' → Sequential (2 separate tool calls)", +- User: 'do task A then task B' → Sequential (2 separate tool calls) + +Pre-created Task Based: +- subrecipe__create_task_weather returns execution_mode: 'parallel' → Use parallel execution +- subrecipe__create_task_weather returns execution_mode: 'sequential' → Use sequential execution", serde_json::json!({ "type": "object", "properties": { @@ -30,7 +42,7 @@ EXAMPLES: "type": "string", "enum": ["sequential", "parallel"], "default": "sequential", - "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." + "description": "Execution strategy for multiple tasks. For pre-created tasks, respect the execution_mode from task creation. For user intent, use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." }, "tasks": { "type": "array", diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 2441684b4b0e..76d7effb48ee 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -5,7 +5,8 @@ use std::collections::HashMap; use crate::{ agents::{ recipe_tools::sub_recipe_tools::{ - create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, + create_multiple_sub_recipe_task_tool, create_multiple_sub_recipe_tasks, + SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, tool_execution::ToolCallResult, }, @@ -34,18 +35,12 @@ impl SubRecipeManager { pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec) { for sub_recipe in sub_recipes_to_add { - // let sub_recipe_key = format!( - // "{}_{}", - // SUB_RECIPE_TOOL_NAME_PREFIX, - // sub_recipe.name.clone() - // ); - // let tool = create_sub_recipe_tool(&sub_recipe); let sub_recipe_key = format!( "{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name.clone() ); - let tool = create_sub_recipe_task_tool(&sub_recipe); + let tool = create_multiple_sub_recipe_task_tool(&sub_recipe); self.sub_recipe_tools.insert(sub_recipe_key.clone(), tool); self.sub_recipes.insert(sub_recipe_key.clone(), sub_recipe); } @@ -111,8 +106,7 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - - let output = create_sub_recipe_task(sub_recipe, params) + let output = create_multiple_sub_recipe_tasks(sub_recipe, params) .await .map_err(|e| { ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 55ff144f781e..75c08a42fa16 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -137,6 +137,21 @@ pub struct SubRecipe { pub path: String, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub executions: Option, +} +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Execution { + #[serde(default)] + pub parallel: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub runs: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ExecutionRun { + #[serde(default, deserialize_with = "deserialize_value_map_as_string")] + pub values: Option>, } fn deserialize_value_map_as_string<'de, D>( From cb0e6a5db20a54d858e0ac0571da473ccfff91f6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 7 Jul 2025 17:53:52 +1000 Subject: [PATCH 02/34] added tests, rename functions --- .../agents/recipe_tools/sub_recipe_tools.rs | 321 ++++------- .../recipe_tools/sub_recipe_tools/tests.rs | 520 ++++++++++++------ crates/goose/src/agents/sub_recipe_manager.rs | 7 +- 3 files changed, 484 insertions(+), 364 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 7ef2e4737bc6..e50ee588700f 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -10,26 +10,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -#[allow(dead_code)] pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); - Tool::new( - format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), - "Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(), - input_schema, - Some(ToolAnnotations { - title: Some(format!("create sub recipe task {}", sub_recipe.name)), - read_only_hint: false, - destructive_hint: true, - idempotent_hint: false, - open_world_hint: true, - }), - ) -} - -#[allow(dead_code)] -pub fn create_multiple_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { - let input_schema = get_input_schema_for_multiple_sub_recipe_task(sub_recipe).unwrap(); Tool::new( format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), format!( @@ -53,6 +35,45 @@ pub fn create_multiple_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } +pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { + let empty_vec = vec![]; + let task_params_array = params + .get("task_parameters") + .and_then(|v| v.as_array()) + .unwrap_or(&empty_vec); + let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; + let tasks = command_params + .iter() + .map(|task_command_param| { + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": task_command_param, + "recipe_path": sub_recipe.path.clone(), + } + }); + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + } + }) + .collect::>(); + let is_parallel = sub_recipe + .executions + .as_ref() + .map(|e| e.parallel) + .unwrap_or(false); + let task_execution_payload = json!({ + "tasks": tasks, + "execution_mode": if is_parallel { "parallel" } else { "sequential" } + }); + + let tasks_json = serde_json::to_string(&task_execution_payload) + .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + Ok(tasks_json) +} + fn get_sub_recipe_parameter_definition( sub_recipe: &SubRecipe, ) -> Result>> { @@ -62,89 +83,26 @@ fn get_sub_recipe_parameter_definition( Ok(recipe.parameters) } -#[allow(dead_code)] -fn get_input_schema(sub_recipe: &SubRecipe) -> Result { - let mut sub_recipe_params_map = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params_map.insert(param_name.clone(), param_value.clone()); - } - } - let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; - if let Some(parameters) = parameter_definition { - let mut properties = Map::new(); - let mut required = Vec::new(); - for param in parameters { - if sub_recipe_params_map.contains_key(¶m.key) { - continue; - } - properties.insert( - param.key.clone(), - json!({ - "type": param.input_type.to_string(), - "description": param.description.clone(), - }), - ); - if !matches!(param.requirement, RecipeParameterRequirement::Optional) { - required.push(param.key); - } - } - Ok(json!({ - "type": "object", - "properties": properties, - "required": required - })) - } else { - Ok(json!({ - "type": "object", - "properties": {} - })) - } -} - -#[allow(dead_code)] -fn get_input_schema_for_multiple_sub_recipe_task(sub_recipe: &SubRecipe) -> Result { - let mut sub_recipe_values = HashSet::::new(); +fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet { + let mut sub_recipe_params_with_values = HashSet::::new(); if let Some(params_with_value) = &sub_recipe.values { for param_name in params_with_value.keys() { - sub_recipe_values.insert(param_name.clone()); + sub_recipe_params_with_values.insert(param_name.clone()); } } - if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { for run in runs { if let Some(params_with_value) = &run.values { for param_name in params_with_value.keys() { - sub_recipe_values.insert(param_name.clone()); + sub_recipe_params_with_values.insert(param_name.clone()); } } } } + sub_recipe_params_with_values +} - let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; - - let mut param_properties = Map::new(); - let mut param_required = Vec::new(); - - if let Some(parameters) = parameter_definition { - for param in parameters { - if sub_recipe_values.contains(¶m.key.clone()) { - continue; - } - param_properties.insert( - param.key.clone(), - json!({ - "type": param.input_type.to_string(), - "description": param.description.clone(), - }), - ); - if !matches!(param.requirement, RecipeParameterRequirement::Optional) { - param_required.push(param.key); - } - } - } - - // Create the schema using only task_parameters for both single and multiple tasks +fn create_input_schema(param_properties: Map, param_required: Vec) -> Value { let mut properties = Map::new(); if !param_properties.is_empty() { properties.insert( @@ -163,141 +121,106 @@ fn get_input_schema_for_multiple_sub_recipe_task(sub_recipe: &SubRecipe) -> Resu }) ); } - Ok(json!({ + json!({ "type": "object", "properties": properties, - })) + }) } -#[allow(dead_code)] -fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Value, -) -> Result> { - let mut sub_recipe_params = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params.insert(param_name.clone(), param_value.clone()); - } - } - if let Some(params_map) = params_from_tool_call.as_object() { - for (key, value) in params_map { - sub_recipe_params.insert( - key.to_string(), - value.as_str().unwrap_or(&value.to_string()).to_string(), +fn get_input_schema(sub_recipe: &SubRecipe) -> Result { + let sub_recipe_params_with_values = get_params_with_values(sub_recipe); + + let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; + + let mut param_properties = Map::new(); + let mut param_required = Vec::new(); + + if let Some(parameters) = parameter_definition { + for param in parameters { + if sub_recipe_params_with_values.contains(¶m.key.clone()) { + continue; + } + param_properties.insert( + param.key.clone(), + json!({ + "type": param.input_type.to_string(), + "description": param.description.clone(), + }), ); + if !matches!(param.requirement, RecipeParameterRequirement::Optional) { + param_required.push(param.key); + } } } - Ok(sub_recipe_params) + Ok(create_input_schema(param_properties, param_required)) +} + +fn extract_run_params( + sub_recipe: &SubRecipe, +) -> (HashMap, Vec>) { + let base_params = sub_recipe.values.clone().unwrap_or_default(); + + let run_params = sub_recipe + .executions + .as_ref() + .and_then(|e| e.runs.as_ref()) + .map(|runs| { + runs.iter() + .map(|run| { + let mut params = base_params.clone(); + if let Some(run_values) = &run.values { + params.extend(run_values.clone()); + } + params + }) + .collect::>() + }) + .unwrap_or_default(); + (base_params, run_params) } -fn prepare_command_params_for_multiple_sub_recipe_task( +fn prepare_command_params( sub_recipe: &SubRecipe, params_from_tool_call: Vec, ) -> Result>> { - let mut sub_recipe_params = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params.insert(param_name.clone(), param_value.clone()); - } - } - let mut sub_recipe_run_params = Vec::>::new(); - if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { - for run in runs { - let mut sub_recipe_run_param = sub_recipe_params.clone(); - if let Some(params_with_value) = &run.values { - sub_recipe_run_param.extend(params_with_value.clone()); - } - sub_recipe_run_params.push(sub_recipe_run_param); - } - } - println!("===== sub_recipe_run_params: {:?}", sub_recipe_run_params); - println!("===== params_from_tool_call: {:?}", params_from_tool_call); + let (base_params, run_params) = extract_run_params(sub_recipe); + if params_from_tool_call.is_empty() { - return Ok(sub_recipe_run_params); + return Ok(run_params); } - if sub_recipe_run_params.len() > 0 && sub_recipe_run_params.len() != params_from_tool_call.len() - { + + if !run_params.is_empty() && run_params.len() != params_from_tool_call.len() { return Err(anyhow::anyhow!( - "The number of runs in the sub recipe does not match the number of task parameters" + "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", + run_params.len(), + params_from_tool_call.len() )); } - let mut sub_recipe_run_params = vec![sub_recipe_params.clone(); params_from_tool_call.len()]; - for (index, sub_recipe_task_param) in params_from_tool_call.iter().enumerate() { - if let Some(params_with_value) = sub_recipe_task_param.as_object() { - for (key, value) in params_with_value { - sub_recipe_run_params[index] - .entry(key.to_string()) - .or_insert_with(|| value.as_str().unwrap_or(&value.to_string()).to_string()); - } - } - } - Ok(sub_recipe_run_params) -} -#[allow(dead_code)] -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { - let command_params = prepare_command_params(sub_recipe, params)?; - let payload = json!({ - "sub_recipe": { - "name": sub_recipe.name.clone(), - "command_parameters": command_params, - "recipe_path": sub_recipe.path.clone(), - } - }); - let task = Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "sub_recipe".to_string(), - payload, + let run_params_for_merging = if run_params.is_empty() { + vec![base_params; params_from_tool_call.len()] + } else { + run_params }; - let task_json = serde_json::to_string(&task) - .map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?; - Ok(task_json) -} -pub async fn create_multiple_sub_recipe_tasks( - sub_recipe: &SubRecipe, - params: Value, -) -> Result { - // Get the task_parameters array - let empty_vec = vec![]; - let task_params_array = params - .get("task_parameters") - .and_then(|v| v.as_array()) - .unwrap_or(&empty_vec); - let command_params = - prepare_command_params_for_multiple_sub_recipe_task(sub_recipe, task_params_array.clone())?; - let tasks = command_params - .iter() - .map(|task_command_param| { - let payload = json!({ - "sub_recipe": { - "name": sub_recipe.name.clone(), - "command_parameters": task_command_param, - "recipe_path": sub_recipe.path.clone(), + let merged_params = params_from_tool_call + .into_iter() + .zip(run_params_for_merging) + .map(|(tool_param, mut run_param_map)| { + if let Some(param_obj) = tool_param.as_object() { + for (key, value) in param_obj { + let value_str = value + .as_str() + .map(String::from) + .unwrap_or_else(|| value.to_string()); + run_param_map.entry(key.clone()).or_insert(value_str); } - }); - Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "sub_recipe".to_string(), - payload, } + run_param_map }) - .collect::>(); - let is_parallel = sub_recipe - .executions - .as_ref() - .map(|e| e.parallel) - .unwrap_or(false); - let task_execution_payload = json!({ - "tasks": tasks, - "execution_mode": if is_parallel { "parallel" } else { "sequential" } - }); - println!("===== task_execution_payload: {:?}", task_execution_payload); + .collect(); - let tasks_json = serde_json::to_string(&task_execution_payload) - .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; - Ok(tasks_json) + Ok(merged_params) } #[cfg(test)] diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 4fa8574d5b13..fac6a095bbfe 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -2,9 +2,12 @@ mod tests { use std::collections::HashMap; - use crate::recipe::SubRecipe; + use crate::recipe::{Execution, ExecutionRun, SubRecipe}; + use serde_json::json; + use serde_json::Value; + use tempfile::TempDir; - fn setup_sub_recipe() -> SubRecipe { + fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), @@ -13,58 +16,225 @@ mod tests { }; sub_recipe } - mod prepare_command_params_tests { - use std::collections::HashMap; - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - prepare_command_params, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; + fn create_execution_values(key: &str, values: Vec) -> Execution { + let runs = values + .iter() + .map(|value| ExecutionRun { + values: Some(HashMap::from([(key.to_string(), value.to_string())])), + }) + .collect(); + Execution { + parallel: true, + runs: Some(runs), + } + } - #[test] - fn test_prepare_command_params_basic() { - let mut params = HashMap::new(); - params.insert("key2".to_string(), "value2".to_string()); + mod prepare_command_params_tests { + use super::*; - let sub_recipe = setup_sub_recipe(); + use crate::agents::recipe_tools::sub_recipe_tools::{ + prepare_command_params, tests::tests::setup_default_sub_recipe, + }; - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result.get("key1"), Some(&"value1".to_string())); - assert_eq!(result.get("key2"), Some(&"value2".to_string())); + mod without_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result.len(), 0); + } } - #[test] - fn test_prepare_command_params_empty() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - executions: None, - }; - let params: HashMap = HashMap::new(); - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 0); + mod with_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = + Some(create_execution_values("key2", vec!["value2".to_string()])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ("key3".to_string(), "value3".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + sub_recipe.executions = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_return_command_param_when_all_from_values_in_sub_recipe() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( + ) { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array); + + assert!(result.is_err()); + } } } - mod get_input_schema_tests { - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - get_input_schema, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; + mod get_input_schema { + use super::*; + use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; + + fn prepare_sub_recipe(sub_recipe_file_content: &str) -> (SubRecipe, TempDir) { + let mut sub_recipe = setup_default_sub_recipe(); + let temp_dir = tempfile::tempdir().unwrap(); + let temp_file = temp_dir.path().join(sub_recipe.path.clone()); + std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + sub_recipe.path = temp_file.to_string_lossy().to_string(); + (sub_recipe, temp_dir) + } - #[test] - fn test_get_input_schema_with_parameters() { - let sub_recipe = setup_sub_recipe(); + fn verify_task_parameters(result: Value, expected_task_parameters_items: Value) { + let task_parameters = result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + .unwrap() + .as_object() + .unwrap(); + let task_parameters_items = task_parameters.get("items").unwrap(); + assert_eq!(&expected_task_parameters_items, task_parameters_items); + } + + mod without_execution_runs { + use super::*; - let sub_recipe_file_content = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -85,40 +255,75 @@ mod tests { ] }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); - - let result = get_input_schema(&sub_recipe).unwrap(); - - // Verify the schema structure - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key2_prop = &properties["key2"]; - assert_eq!(key2_prop["type"], "number"); - assert_eq!(key2_prop["description"], "An optional parameter"); - - let required = result["required"].as_array().unwrap(); - assert_eq!(required.len(), 0); + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": [] + }), + ); + } + + #[test] + fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = get_input_schema(&sub_recipe).unwrap(); + + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } + + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = None; + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A test parameter" }, + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": ["key1"] + }), + ); + } } - #[test] - fn test_get_input_schema_no_parameters_values() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - executions: None, - }; + mod execution_runs { + use super::*; - let sub_recipe_file_content = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -128,102 +333,95 @@ mod tests { "key": "key1", "input_type": "string", "requirement": "required", - "description": "A test parameter" - } - ] - }"#; - - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); - - let result = get_input_schema(&sub_recipe).unwrap(); - - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key1_prop = &properties["key1"]; - assert_eq!(key1_prop["type"], "string"); - assert_eq!(key1_prop["description"], "A test parameter"); - assert_eq!(result["required"].as_array().unwrap().len(), 1); - assert_eq!(result["required"][0], "key1"); - } - } - - mod get_input_schema_for_multiple_sub_recipe_task_tests { - use crate::{ - agents::recipe_tools::sub_recipe_tools::get_input_schema_for_multiple_sub_recipe_task, - recipe::SubRecipe, - }; - - #[test] - fn test_get_input_schema_for_multiple_tasks() { - let sub_recipe_file_content = r#"{ - "version": "1.0.0", - "title": "Test Recipe", - "description": "A test recipe", - "prompt": "Test prompt", - "parameters": [ - { - "key": "param1", - "input_type": "string", - "requirement": "required", - "description": "A required parameter" + "description": "A required string parameter" }, { - "key": "param2", + "key": "key2", "input_type": "number", "requirement": "optional", "description": "An optional parameter" + }, + { + "key": "key3", + "input_type": "date", + "requirement": "required", + "description": "A required date parameter" } ] }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: temp_file.to_string_lossy().to_string(), - values: None, - executions: None, - }; - - let result = get_input_schema_for_multiple_sub_recipe_task(&sub_recipe).unwrap(); - - // Verify the schema structure - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - assert!(properties.contains_key("task_parameters")); - - let task_params = &properties["task_parameters"]; - assert_eq!(task_params["type"], "array"); - assert_eq!(task_params["minItems"], 1); - - let items = &task_params["items"]; - assert_eq!(items["type"], "object"); - - let item_properties = items["properties"].as_object().unwrap(); - assert_eq!(item_properties.len(), 2); - assert!(item_properties.contains_key("param1")); - assert!(item_properties.contains_key("param2")); - - assert_eq!(item_properties["param1"]["type"], "string"); - assert_eq!(item_properties["param2"]["type"], "number"); - - let required = items["required"].as_array().unwrap(); - assert_eq!(required.len(), 1); - assert_eq!(required[0], "param1"); + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value_1".to_string(), "key2_value_2".to_string()], + )); + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key3": { "type": "date", "description": "A required date parameter" } + }, + "required": ["key3"] + }), + ); + } + + #[test] + fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value_1".to_string(), "key2_value_2".to_string()], + )); + + let result = get_input_schema(&sub_recipe).unwrap(); + + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } + + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); + sub_recipe.values = None; + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A required string parameter" }, + "key2": { "type": "number", "description": "An optional parameter" }, + "key3": { "type": "date", "description": "A required date parameter" } + }, + "required": ["key1", "key3"] + }), + ); + } } } } diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 76d7effb48ee..d759914a9d9e 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -5,8 +5,7 @@ use std::collections::HashMap; use crate::{ agents::{ recipe_tools::sub_recipe_tools::{ - create_multiple_sub_recipe_task_tool, create_multiple_sub_recipe_tasks, - SUB_RECIPE_TASK_TOOL_NAME_PREFIX, + create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, tool_execution::ToolCallResult, }, @@ -40,7 +39,7 @@ impl SubRecipeManager { SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name.clone() ); - let tool = create_multiple_sub_recipe_task_tool(&sub_recipe); + let tool = create_sub_recipe_task_tool(&sub_recipe); self.sub_recipe_tools.insert(sub_recipe_key.clone(), tool); self.sub_recipes.insert(sub_recipe_key.clone(), sub_recipe); } @@ -106,7 +105,7 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - let output = create_multiple_sub_recipe_tasks(sub_recipe, params) + let output = create_sub_recipe_task(sub_recipe, params) .await .map_err(|e| { ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) From 476651ffcb4418e320386b9bee25b7c460800683 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 14:13:09 +1000 Subject: [PATCH 03/34] get sub recipe json output --- .../agents/sub_recipe_execution_tool/tasks.rs | 89 ++++++++++++------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 4e4584aa0b34..b934a13344fc 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -7,12 +7,10 @@ use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; -// Process a single task based on its type pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - // Execute with timeout match timeout(timeout_duration, execute_task(task_clone)).await { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), @@ -36,6 +34,17 @@ pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { } async fn execute_task(task: Task) -> Result { + let (command, output_identifier) = build_command(&task)?; + let (stdout_output, stderr_output, success) = run_command(command, &output_identifier).await?; + + if success { + process_output(stdout_output) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } +} + +fn build_command(task: &Task) -> Result<(Command, String), String> { let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { let sub_recipe = task.payload.get("sub_recipe").unwrap(); @@ -66,11 +75,12 @@ async fn execute_task(task: Task) -> Result { cmd }; - // Configure to capture stdout command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + Ok((command, output_identifier)) +} - // Spawn the child process +async fn run_command(mut command: Command, output_identifier: &str) -> Result<(String, String, bool), String> { let mut child = command .spawn() .map_err(|e| format!("Failed to spawn goose: {}", e))?; @@ -78,30 +88,8 @@ async fn execute_task(task: Task) -> Result { let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let mut stdout_reader = BufReader::new(stdout).lines(); - let mut stderr_reader = BufReader::new(stderr).lines(); - - // Spawn background tasks to read from stdout and stderr - let output_identifier_clone = output_identifier.clone(); - let stdout_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stdout_reader.next_line().await { - println!("[{}] {}", output_identifier_clone, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); - - let stderr_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stderr_reader.next_line().await { - eprintln!("[stderr for {}] {}", output_identifier, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); + let stdout_task = spawn_output_reader(stdout, output_identifier, false); + let stderr_task = spawn_output_reader(stderr, output_identifier, true); let status = child .wait() @@ -111,9 +99,46 @@ async fn execute_task(task: Task) -> Result { let stdout_output = stdout_task.await.unwrap(); let stderr_output = stderr_task.await.unwrap(); - if status.success() { - Ok(Value::String(stdout_output)) - } else { - Err(format!("Command failed:\n{}", stderr_output)) + Ok((stdout_output, stderr_output, status.success())) +} + +fn spawn_output_reader( + reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, + output_identifier: &str, + is_stderr: bool, +) -> tokio::task::JoinHandle { + let output_identifier = output_identifier.to_string(); + tokio::spawn(async move { + let mut buffer = String::new(); + let mut lines = BufReader::new(reader).lines(); + while let Ok(Some(line)) = lines.next_line().await { + if is_stderr { + eprintln!("[stderr for {}] {}", output_identifier, line); + } else { + println!("[{}] {}", output_identifier, line); + } + buffer.push_str(&line); + buffer.push('\n'); + } + buffer + }) +} + +fn process_output(stdout_output: String) -> Result { + let last_line = stdout_output + .lines() + .filter(|line| !line.trim().is_empty()) + .last() + .unwrap_or(""); + + if let (Some(start), Some(end)) = (last_line.find('{'), last_line.rfind('}')) { + if start < end { + let potential_json = &last_line[start..=end]; + + if serde_json::from_str::(potential_json).is_ok() { + return Ok(Value::String(potential_json.to_string())); + } + } } + Ok(Value::String(stdout_output)) } From b50d245abe98fc40a0efcbe123148fae231c6f91 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 16:53:07 +1000 Subject: [PATCH 04/34] added showing parallel running dashboard --- .../sub_recipe_execution_tool/dashboard.rs | 192 +++++++++++++ .../sub_recipe_execution_tool/executor.rs | 223 +++++++++++---- .../agents/sub_recipe_execution_tool/lib.rs | 2 +- .../agents/sub_recipe_execution_tool/mod.rs | 2 + .../agents/sub_recipe_execution_tool/tasks.rs | 61 ++-- .../agents/sub_recipe_execution_tool/types.rs | 57 +++- .../sub_recipe_execution_tool/utils/mod.rs | 65 +++++ .../sub_recipe_execution_tool/utils/tests.rs | 262 ++++++++++++++++++ .../sub_recipe_execution_tool/workers.rs | 105 ++----- 9 files changed, 808 insertions(+), 161 deletions(-) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs new file mode 100644 index 000000000000..68867bb8e5c8 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; +use std::io::{self, Write}; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio::time::{Duration, Instant}; + +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::utils::{ + count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, +}; + +pub struct TaskDashboard { + tasks: Arc>>, + last_display: Arc>, + last_refresh: Arc>, +} + +impl TaskDashboard { + pub fn new(tasks: Vec) -> Self { + let task_map = tasks + .into_iter() + .map(|task| { + let task_id = task.id.clone(); + ( + task_id, + TaskInfo { + task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }, + ) + }) + .collect(); + + Self { + tasks: Arc::new(RwLock::new(task_map)), + last_display: Arc::new(RwLock::new(String::new())), + last_refresh: Arc::new(RwLock::new(Instant::now())), + } + } + + pub async fn start_task(&self, task_id: &str) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = TaskStatus::Running; + task_info.start_time = Some(Instant::now()); + } + drop(tasks); + self.refresh_display().await; + } + + pub async fn complete_task(&self, task_id: &str, result: TaskResult) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = result.status.clone(); + task_info.end_time = Some(Instant::now()); + task_info.result = Some(result); + } + drop(tasks); + self.refresh_display().await; + } + + pub async fn update_task_output(&self, task_id: &str, output: &str) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + // Keep only the last few lines to avoid overwhelming display + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > 2 { + &lines[lines.len() - 2..] + } else { + &lines + }; + + // Strip ANSI escape sequences to prevent color flashing + let clean_output = recent_lines.join("\n"); + task_info.current_output = strip_ansi_codes(&clean_output); + } + drop(tasks); + + // Throttle refreshes to avoid overwhelming the display (max 1 per second) + let now = Instant::now(); + let mut last_refresh = self.last_refresh.write().await; + if now.duration_since(*last_refresh) > Duration::from_millis(1000) { + *last_refresh = now; + drop(last_refresh); + self.refresh_display().await; + } + } + + pub async fn refresh_display(&self) { + let tasks = self.tasks.read().await; + let mut display = String::new(); + + // Clear screen and move to top + display.push_str("\x1b[2J\x1b[H"); + + // Title + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + + // Summary stats + let (total, pending, running, completed, failed) = count_by_status(&tasks); + + display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed\n\n", + total, pending, running, completed, failed)); + + // Task list + let mut task_list: Vec<_> = tasks.values().collect(); + task_list.sort_by_key(|t| &t.task.id); + + for task_info in task_list { + let status_icon = match task_info.status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + }; + + let task_name = get_task_name(task_info); + + display.push_str(&format!( + "{} {} ({})\n", + status_icon, task_name, task_info.task.task_type + )); + + if let Some(start_time) = task_info.start_time { + let duration = if let Some(end_time) = task_info.end_time { + end_time.duration_since(start_time) + } else { + Instant::now().duration_since(start_time) + }; + display.push_str(&format!(" ⏱️ {:.1}s\n", duration.as_secs_f64())); + } + + if matches!(task_info.status, TaskStatus::Running) + && !task_info.current_output.is_empty() + { + let output_preview = truncate_with_ellipsis(&task_info.current_output, 100); + display.push_str(&format!(" 💬 {}\n", output_preview.replace('\n', " | "))); + } + + if let Some(error) = task_info.error() { + let error_preview = truncate_with_ellipsis(error, 80); + display.push_str(&format!(" ⚠️ {}\n", error_preview.replace('\n', " "))); + } + + display.push('\n'); + } + + // Only update display if it changed + let mut last_display = self.last_display.write().await; + if *last_display != display { + print!("{}", display); + io::stdout().flush().unwrap(); + *last_display = display; + } + } + + pub async fn show_final_summary(&self) { + let tasks = self.tasks.read().await; + + println!("\n🎉 Execution Complete!"); + println!("═══════════════════════"); + + let (total, _, _, completed, failed) = count_by_status(&tasks); + + println!("📊 Final Results:"); + println!(" Total Tasks: {}", total); + println!(" ✅ Completed: {}", completed); + println!(" ❌ Failed: {}", failed); + println!( + " 📈 Success Rate: {:.1}%", + (completed as f64 / total as f64) * 100.0 + ); + + if failed > 0 { + println!("\n❌ Failed Tasks:"); + for task_info in tasks.values() { + if matches!(task_info.status, TaskStatus::Failed) { + let task_name = get_task_name(task_info); + println!(" • {}", task_name); + if let Some(error) = task_info.error() { + println!(" Error: {}", error); + } + } + } + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index b796d412984d..9d869666e4af 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -1,103 +1,208 @@ -use std::sync::atomic::{AtomicBool, AtomicUsize}; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; use crate::agents::sub_recipe_execution_tool::lib::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, + Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState}; +use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; + +const EXECUTION_STATUS_COMPLETED: &str = "completed"; pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { let start_time = Instant::now(); let result = process_task(task, config.timeout_seconds).await; let execution_time = start_time.elapsed().as_millis(); - let completed = if result.status == "success" { 1 } else { 0 }; - let failed = if result.status == "failed" { 1 } else { 0 }; + let stats = calculate_stats(&[result.clone()], execution_time); ExecutionResponse { - status: "completed".to_string(), + status: EXECUTION_STATUS_COMPLETED.to_string(), results: vec![result], - stats: ExecutionStats { - total_tasks: 1, - completed, - failed, - execution_time_ms: execution_time, - }, + stats, } } -// Main parallel execution function -pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { - let start_time = Instant::now(); - let task_count = tasks.len(); +fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats { + let completed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Completed)) + .count(); + let failed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .count(); + + ExecutionStats { + total_tasks: results.len(), + completed, + failed, + execution_time_ms, + } +} + +struct ExecutionContext { + tasks: Vec, + config: Config, + dashboard: Arc, + start_time: Instant, +} + +impl ExecutionContext { + fn new(tasks: Vec, config: Config) -> Self { + let dashboard = Arc::new(TaskDashboard::new(tasks.clone())); + + Self { + tasks, + config, + dashboard, + start_time: Instant::now(), + } + } - // Create channels + fn task_count(&self) -> usize { + self.tasks.len() + } +} + +fn create_channels( + task_count: usize, +) -> ( + mpsc::Sender, + mpsc::Receiver, + mpsc::Sender, + mpsc::Receiver, +) { let (task_tx, task_rx) = mpsc::channel::(task_count); - let (result_tx, mut result_rx) = mpsc::channel::(task_count); + let (result_tx, result_rx) = mpsc::channel::(task_count); + (task_tx, task_rx, result_tx, result_rx) +} - // Initialize shared state - let shared_state = Arc::new(SharedState { +fn create_shared_state( + task_rx: mpsc::Receiver, + result_tx: mpsc::Sender, + dashboard: Arc, +) -> Arc { + Arc::new(SharedState { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), - }); + dashboard: Some(dashboard), + }) +} - // Send all tasks to the queue - for task in tasks.clone() { - let _ = task_tx.send(task).await; +async fn send_tasks_to_channel( + tasks: Vec, + task_tx: mpsc::Sender, +) -> Result<(), String> { + for task in tasks { + task_tx + .send(task) + .await + .map_err(|e| format!("Failed to queue task: {}", e))?; } - // Close sender so workers know when queue is empty - drop(task_tx); + Ok(()) +} - // Start initial workers - let mut worker_handles = Vec::new(); - for i in 0..config.initial_workers { - let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); - worker_handles.push(handle); +fn create_empty_response() -> ExecutionResponse { + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results: vec![], + stats: ExecutionStats { + total_tasks: 0, + completed: 0, + failed: 0, + execution_time_ms: 0, + }, } +} - // Start the scaler - let scaler_state = shared_state.clone(); - let scaler_handle = tokio::spawn(async move { - run_scaler( - scaler_state, - task_count, - config.max_workers, - config.timeout_seconds, - ) - .await; - }); - - // Collect results +async fn collect_results( + result_rx: &mut mpsc::Receiver, + dashboard: Arc, + expected_count: usize, +) -> Vec { let mut results = Vec::new(); while let Some(result) = result_rx.recv().await { + dashboard + .complete_task(&result.task_id, result.clone()) + .await; results.push(result); - if results.len() >= task_count { + if results.len() >= expected_count { break; } } + results +} - // Wait for scaler to finish - let _ = scaler_handle.await; +async fn execute_with_context(ctx: ExecutionContext) -> Result { + let task_count = ctx.task_count(); - // Calculate stats - let execution_time = start_time.elapsed().as_millis(); - let completed = results.iter().filter(|r| r.status == "success").count(); - let failed = results.iter().filter(|r| r.status == "failed").count(); + if task_count == 0 { + return Ok(create_empty_response()); + } - ExecutionResponse { - status: "completed".to_string(), + ctx.dashboard.refresh_display().await; + + let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); + + send_tasks_to_channel(ctx.tasks, task_tx).await?; + + let shared_state = create_shared_state(task_rx, result_tx, ctx.dashboard.clone()); + + // Simple static worker allocation - no dynamic scaling needed + let worker_count = std::cmp::min(task_count, ctx.config.max_workers); + let mut worker_handles = Vec::new(); + for i in 0..worker_count { + let handle = spawn_worker(shared_state.clone(), i, ctx.config.timeout_seconds); + worker_handles.push(handle); + } + + let results = collect_results(&mut result_rx, ctx.dashboard.clone(), task_count).await; + + // Wait for all workers to finish + for handle in worker_handles { + if let Err(e) = handle.await { + eprintln!("Worker error: {}", e); + } + } + + ctx.dashboard.show_final_summary().await; + + let execution_time = ctx.start_time.elapsed().as_millis(); + let stats = calculate_stats(&results, execution_time); + + Ok(ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), results, + stats, + }) +} + +fn create_error_response(_error: String) -> ExecutionResponse { + ExecutionResponse { + status: "failed".to_string(), + results: vec![], stats: ExecutionStats { - total_tasks: task_count, - completed, - failed, - execution_time_ms: execution_time, + total_tasks: 0, + completed: 0, + failed: 1, + execution_time_ms: 0, }, } } + +pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { + let ctx = ExecutionContext::new(tasks, config); + + match execute_with_context(ctx).await { + Ok(response) => response, + Err(e) => { + eprintln!("Execution failed: {}", e); + create_error_response(e) + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 9df784a46be0..6c4718c2b3a0 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -1,7 +1,7 @@ use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, + Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index a49791e2776f..6f131862fefd 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,6 +1,8 @@ +mod dashboard; mod executor; pub mod lib; pub mod sub_recipe_execute_task_tool; mod tasks; mod types; +pub mod utils; mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index b934a13344fc..9543f0cec504 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -1,42 +1,53 @@ use serde_json::Value; use std::process::Stdio; +use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { + process_task_with_dashboard(task, timeout_seconds, None).await +} + +pub async fn process_task_with_dashboard( + task: &Task, + timeout_seconds: u64, + dashboard: Option>, +) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - match timeout(timeout_duration, execute_task(task_clone)).await { + match timeout(timeout_duration, execute_task(task_clone, dashboard)).await { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), - status: "success".to_string(), + status: TaskStatus::Completed, data: Some(data), error: None, }, Ok(Err(error)) => TaskResult { task_id: task.id.clone(), - status: "failed".to_string(), + status: TaskStatus::Failed, data: None, error: Some(error), }, Err(_) => TaskResult { task_id: task.id.clone(), - status: "failed".to_string(), + status: TaskStatus::Failed, data: None, error: Some("Task timeout".to_string()), }, } } -async fn execute_task(task: Task) -> Result { +async fn execute_task(task: Task, dashboard: Option>) -> Result { let (command, output_identifier) = build_command(&task)?; - let (stdout_output, stderr_output, success) = run_command(command, &output_identifier).await?; - + let (stdout_output, stderr_output, success) = + run_command(command, &output_identifier, &task.id, dashboard).await?; + if success { process_output(stdout_output) } else { @@ -80,7 +91,12 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { Ok((command, output_identifier)) } -async fn run_command(mut command: Command, output_identifier: &str) -> Result<(String, String, bool), String> { +async fn run_command( + mut command: Command, + output_identifier: &str, + task_id: &str, + dashboard: Option>, +) -> Result<(String, String, bool), String> { let mut child = command .spawn() .map_err(|e| format!("Failed to spawn goose: {}", e))?; @@ -88,8 +104,9 @@ async fn run_command(mut command: Command, output_identifier: &str) -> Result<(S let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let stdout_task = spawn_output_reader(stdout, output_identifier, false); - let stderr_task = spawn_output_reader(stderr, output_identifier, true); + let stdout_task = + spawn_output_reader(stdout, output_identifier, false, task_id, dashboard.clone()); + let stderr_task = spawn_output_reader(stderr, output_identifier, true, task_id, None); let status = child .wait() @@ -106,19 +123,25 @@ fn spawn_output_reader( reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, output_identifier: &str, is_stderr: bool, + task_id: &str, + dashboard: Option>, ) -> tokio::task::JoinHandle { let output_identifier = output_identifier.to_string(); + let task_id = task_id.to_string(); tokio::spawn(async move { let mut buffer = String::new(); let mut lines = BufReader::new(reader).lines(); while let Ok(Some(line)) = lines.next_line().await { - if is_stderr { - eprintln!("[stderr for {}] {}", output_identifier, line); - } else { - println!("[{}] {}", output_identifier, line); - } buffer.push_str(&line); buffer.push('\n'); + + if !is_stderr { + if let Some(dashboard) = &dashboard { + dashboard.update_task_output(&task_id, &buffer).await; + } + } else { + eprintln!("[stderr for {}] {}", output_identifier, line); + } } buffer }) @@ -128,13 +151,13 @@ fn process_output(stdout_output: String) -> Result { let last_line = stdout_output .lines() .filter(|line| !line.trim().is_empty()) - .last() + .next_back() .unwrap_or(""); - + if let (Some(start), Some(end)) = (last_line.find('{'), last_line.rfind('}')) { if start < end { let potential_json = &last_line[start..=end]; - + if serde_json::from_str::(potential_json).is_ok() { return Ok(Value::String(potential_json.to_string())); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index ede71dbf40b4..2f40490a564c 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -1,7 +1,11 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc; + +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; -// Task definition that LLMs will send #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { pub id: String, @@ -9,18 +13,61 @@ pub struct Task { pub payload: Value, } -// Result for each task #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TaskResult { pub task_id: String, - pub status: String, + pub status: TaskStatus, #[serde(skip_serializing_if = "Option::is_none")] pub data: Option, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } -// Configuration for the parallel executor +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum TaskStatus { + Pending, + Running, + Completed, + Failed, +} + +#[derive(Debug, Clone)] +pub struct TaskInfo { + pub task: Task, + pub status: TaskStatus, + pub start_time: Option, + pub end_time: Option, + pub result: Option, + pub current_output: String, +} + +impl TaskInfo { + pub fn error(&self) -> Option<&String> { + self.result.as_ref().and_then(|r| r.error.as_ref()) + } + + pub fn data(&self) -> Option<&Value> { + self.result.as_ref().and_then(|r| r.data.as_ref()) + } +} + +pub struct SharedState { + pub task_receiver: Arc>>, + pub result_sender: mpsc::Sender, + pub active_workers: Arc, + pub dashboard: Option>, +} + +impl SharedState { + pub fn increment_active_workers(&self) { + self.active_workers.fetch_add(1, Ordering::SeqCst); + } + + pub fn decrement_active_workers(&self) { + self.active_workers.fetch_sub(1, Ordering::SeqCst); + } +} + #[derive(Debug, Clone, Deserialize)] pub struct Config { #[serde(default = "default_max_workers")] @@ -51,7 +98,6 @@ fn default_initial_workers() -> usize { 2 } -// Stats for the execution #[derive(Debug, Serialize)] pub struct ExecutionStats { pub total_tasks: usize, @@ -60,7 +106,6 @@ pub struct ExecutionStats { pub execution_time_ms: u128, } -// Main response structure #[derive(Debug, Serialize)] pub struct ExecutionResponse { pub status: String, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs new file mode 100644 index 000000000000..ebd12e7f11e3 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; + +use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; + +pub fn get_task_name(task_info: &TaskInfo) -> &str { + if task_info.task.task_type == "sub_recipe" { + task_info + .task + .payload + .get("sub_recipe") + .and_then(|sr| sr.get("name")) + .and_then(|n| n.as_str()) + .unwrap_or(&task_info.task.id) + } else { + &task_info.task.id + } +} + +pub fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { + if text.len() > max_len { + format!("{}...", &text[..max_len.saturating_sub(3)]) + } else { + text.to_string() + } +} + +pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { + let total = tasks.len(); + let (pending, running, completed, failed) = tasks.values().fold( + (0, 0, 0, 0), + |(pending, running, completed, failed), task| match task.status { + TaskStatus::Pending => (pending + 1, running, completed, failed), + TaskStatus::Running => (pending, running + 1, completed, failed), + TaskStatus::Completed => (pending, running, completed + 1, failed), + TaskStatus::Failed => (pending, running, completed, failed + 1), + }, + ); + (total, pending, running, completed, failed) +} + +pub fn strip_ansi_codes(text: &str) -> String { + let mut result = String::new(); + let mut chars = text.chars(); + + while let Some(ch) = chars.next() { + if ch == '\x1b' { + if chars.next() == Some('[') { + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } + } + } + } else { + result.push(ch); + } + } + + result +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs new file mode 100644 index 000000000000..a2e1d9b74a8f --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -0,0 +1,262 @@ +#[cfg(test)] +mod tests { + use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; + use crate::agents::sub_recipe_execution_tool::utils::{ + count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, + }; + use serde_json::json; + use std::collections::HashMap; + + mod truncate_with_ellipsis { + use super::*; + + #[test] + fn returns_original_when_under_limit() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); + } + + #[test] + fn truncates_when_over_limit() { + assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); + assert_eq!( + truncate_with_ellipsis("very long text here", 10), + "very lo..." + ); + } + + #[test] + fn handles_empty_string() { + assert_eq!(truncate_with_ellipsis("", 5), ""); + } + + #[test] + fn handles_exact_limit() { + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + } + + #[test] + fn handles_very_short_limit() { + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is + } + } + + mod strip_ansi_codes { + use super::*; + + #[test] + fn preserves_plain_text() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); + } + + #[test] + fn removes_color_codes() { + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); + } + + #[test] + fn removes_complex_formatting() { + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), + "underline red" + ); + } + + #[test] + fn handles_multiple_sequences() { + let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; + assert_eq!(strip_ansi_codes(input), "red normal green"); + } + + #[test] + fn handles_empty_string() { + assert_eq!(strip_ansi_codes(""), ""); + } + + #[test] + fn handles_malformed_sequences() { + // Incomplete escape sequence - our function strips the \x1b char + assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); + } + } + + mod get_task_name { + use super::*; + + #[test] + fn extracts_sub_recipe_name() { + let sub_recipe_task = Task { + id: "task_1".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": "my_recipe", + "recipe_path": "/path/to/recipe" + } + }), + }; + + let task_info = TaskInfo { + task: sub_recipe_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "my_recipe"); + } + + #[test] + fn falls_back_to_task_id_for_text_instruction() { + let text_task = Task { + id: "task_2".to_string(), + task_type: "text_instruction".to_string(), + payload: json!({"text_instruction": "do something"}), + }; + + let task_info = TaskInfo { + task: text_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_2"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_name_missing() { + let malformed_task = Task { + id: "task_3".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "recipe_path": "/path/to/recipe" + // missing "name" field + } + }), + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_3"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_missing() { + let malformed_task = Task { + id: "task_4".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({}), // missing "sub_recipe" field + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_4"); + } + } + + mod count_by_status { + use super::*; + + fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { + TaskInfo { + task: Task { + id: id.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + } + } + + #[test] + fn counts_empty_map() { + let tasks = HashMap::new(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (0, 0, 0, 0, 0) + ); + } + + #[test] + fn counts_single_status() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Pending), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (2, 2, 0, 0, 0) + ); + } + + #[test] + fn counts_mixed_statuses() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Running), + ); + tasks.insert( + "task3".to_string(), + create_test_task("task3", TaskStatus::Completed), + ); + tasks.insert( + "task4".to_string(), + create_test_task("task4", TaskStatus::Failed), + ); + tasks.insert( + "task5".to_string(), + create_test_task("task5", TaskStatus::Completed), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (5, 1, 1, 2, 1) + ); + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index e48f19c4d360..d9f750480432 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,9 +1,7 @@ -use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::tasks::{process_task, process_task_with_dashboard}; +use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task, TaskResult}; use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::{sleep, Duration}; #[cfg(test)] mod tests { @@ -22,6 +20,7 @@ mod tests { active_workers: Arc::new(AtomicUsize::new(0)), should_stop: Arc::new(AtomicBool::new(false)), completed_tasks: Arc::new(AtomicUsize::new(0)), + dashboard: None, }); // Test that spawn_worker returns a JoinHandle @@ -40,21 +39,33 @@ mod tests { } } -pub struct SharedState { - pub task_receiver: Arc>>, - pub result_sender: mpsc::Sender, - pub active_workers: Arc, - pub should_stop: Arc, - pub completed_tasks: Arc, +async fn receive_task(state: &SharedState) -> Option { + let mut receiver = state.task_receiver.lock().await; + receiver.recv().await +} + +async fn execute_task( + task: Task, + timeout: u64, + dashboard: Option>, +) -> TaskResult { + if let Some(dashboard) = &dashboard { + dashboard.start_task(&task.id).await; + } + + if let Some(dashboard) = dashboard { + process_task_with_dashboard(&task, timeout, Some(dashboard)).await + } else { + process_task(&task, timeout).await + } } -// Spawn a worker task pub fn spawn_worker( state: Arc, worker_id: usize, timeout_seconds: u64, ) -> tokio::task::JoinHandle<()> { - state.active_workers.fetch_add(1, Ordering::SeqCst); + state.increment_active_workers(); tokio::spawn(async move { worker_loop(state, worker_id, timeout_seconds).await; @@ -62,72 +73,14 @@ pub fn spawn_worker( } async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { - loop { - // Try to receive a task - let task = { - let mut receiver = state.task_receiver.lock().await; - receiver.recv().await - }; - - match task { - Some(task) => { - // Process the task - let result = process_task(&task, timeout_seconds).await; + while let Some(task) = receive_task(&state).await { + let result = execute_task(task, timeout_seconds, state.dashboard.clone()).await; - // Send result - let _ = state.result_sender.send(result).await; - - // Update completed count - state.completed_tasks.fetch_add(1, Ordering::SeqCst); - } - None => { - // Channel closed, exit worker - break; - } - } - - // Check if we should stop - if state.should_stop.load(Ordering::SeqCst) { + if let Err(e) = state.result_sender.send(result).await { + eprintln!("Worker failed to send result: {}", e); break; } } - // Worker is exiting - state.active_workers.fetch_sub(1, Ordering::SeqCst); -} - -// Scaling controller that monitors queue and spawns workers -pub async fn run_scaler( - state: Arc, - task_count: usize, - max_workers: usize, - timeout_seconds: u64, -) { - let mut worker_count = 0; - - loop { - sleep(Duration::from_millis(100)).await; - - let active = state.active_workers.load(Ordering::SeqCst); - let completed = state.completed_tasks.load(Ordering::SeqCst); - let pending = task_count.saturating_sub(completed); - - // Simple scaling logic: spawn worker if many pending tasks and under limit - if pending > active * 2 && active < max_workers && worker_count < max_workers { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - - // If all tasks completed, signal stop - if completed >= task_count { - state.should_stop.store(true, Ordering::SeqCst); - break; - } - - // If no active workers and tasks remaining, spawn one - if active == 0 && pending > 0 { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - } + state.decrement_active_workers(); } From 9ca6b25e6734d343b87c26b3827ab2c7ed7821bc Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 17:21:20 +1000 Subject: [PATCH 05/34] better view --- .../sub_recipe_execution_tool/dashboard.rs | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs index 68867bb8e5c8..c74cb7024187 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -13,6 +13,7 @@ pub struct TaskDashboard { tasks: Arc>>, last_display: Arc>, last_refresh: Arc>, + initial_display_shown: Arc>, } impl TaskDashboard { @@ -39,6 +40,7 @@ impl TaskDashboard { tasks: Arc::new(RwLock::new(task_map)), last_display: Arc::new(RwLock::new(String::new())), last_refresh: Arc::new(RwLock::new(Instant::now())), + initial_display_shown: Arc::new(RwLock::new(false)), } } @@ -94,20 +96,28 @@ impl TaskDashboard { let tasks = self.tasks.read().await; let mut display = String::new(); - // Clear screen and move to top - display.push_str("\x1b[2J\x1b[H"); - - // Title - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); + let mut initial_shown = self.initial_display_shown.write().await; + if !*initial_shown { + // Clear screen and show header only on first display + display.push_str("\x1b[2J\x1b[H"); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + *initial_shown = true; + } else { + // Move cursor to beginning of progress line (line 4) + display.push_str("\x1b[4;1H"); + } + drop(initial_shown); - // Summary stats + // Summary stats (this line gets updated in-place) let (total, pending, running, completed, failed) = count_by_status(&tasks); - - display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed\n\n", + display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", total, pending, running, completed, failed)); - // Task list + // Clear to end of line and add newlines + display.push_str("\x1b[K\n\n"); + + // Task list (update in-place) let mut task_list: Vec<_> = tasks.values().collect(); task_list.sort_by_key(|t| &t.task.id); @@ -122,9 +132,10 @@ impl TaskDashboard { let task_name = get_task_name(task_info); display.push_str(&format!( - "{} {} ({})\n", + "{} {} ({})", status_icon, task_name, task_info.task.task_type )); + display.push_str("\x1b[K\n"); // Clear to end of line if let Some(start_time) = task_info.start_time { let duration = if let Some(end_time) = task_info.end_time { @@ -132,24 +143,30 @@ impl TaskDashboard { } else { Instant::now().duration_since(start_time) }; - display.push_str(&format!(" ⏱️ {:.1}s\n", duration.as_secs_f64())); + display.push_str(&format!(" ⏱️ {:.1}s", duration.as_secs_f64())); + display.push_str("\x1b[K\n"); // Clear to end of line } if matches!(task_info.status, TaskStatus::Running) && !task_info.current_output.is_empty() { let output_preview = truncate_with_ellipsis(&task_info.current_output, 100); - display.push_str(&format!(" 💬 {}\n", output_preview.replace('\n', " | "))); + display.push_str(&format!(" 💬 {}", output_preview.replace('\n', " | "))); + display.push_str("\x1b[K\n"); // Clear to end of line } if let Some(error) = task_info.error() { let error_preview = truncate_with_ellipsis(error, 80); - display.push_str(&format!(" ⚠️ {}\n", error_preview.replace('\n', " "))); + display.push_str(&format!(" ⚠️ {}", error_preview.replace('\n', " "))); + display.push_str("\x1b[K\n"); // Clear to end of line } - display.push('\n'); + display.push_str("\x1b[K\n"); // Clear to end of line and add blank line } + // Clear any remaining lines below + display.push_str("\x1b[J"); + // Only update display if it changed let mut last_display = self.last_display.write().await; if *last_display != display { From 744669280d5485060fd8f889b0413a8b326de8b1 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 18:04:45 +1000 Subject: [PATCH 06/34] clean up dashboard.rs --- .../sub_recipe_execution_tool/dashboard.rs | 155 ++++++++---------- .../sub_recipe_execution_tool/utils/mod.rs | 117 +++++++++++++ 2 files changed, 183 insertions(+), 89 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs index c74cb7024187..ee8ae16bd4b5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -2,13 +2,19 @@ use std::collections::HashMap; use std::io::{self, Write}; use std::sync::Arc; use tokio::sync::RwLock; -use tokio::time::{Duration, Instant}; +use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, + count_by_status, format_task_display, get_task_name, process_output_lines, }; +const THROTTLE_INTERVAL_MS: u64 = 1000; +const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; +const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; +const CLEAR_TO_EOL: &str = "\x1b[K"; +const CLEAR_BELOW: &str = "\x1b[J"; + pub struct TaskDashboard { tasks: Arc>>, last_display: Arc>, @@ -68,128 +74,96 @@ impl TaskDashboard { pub async fn update_task_output(&self, task_id: &str, output: &str) { let mut tasks = self.tasks.write().await; if let Some(task_info) = tasks.get_mut(task_id) { - // Keep only the last few lines to avoid overwhelming display - let lines: Vec<&str> = output.lines().collect(); - let recent_lines = if lines.len() > 2 { - &lines[lines.len() - 2..] - } else { - &lines - }; - - // Strip ANSI escape sequences to prevent color flashing - let clean_output = recent_lines.join("\n"); - task_info.current_output = strip_ansi_codes(&clean_output); + task_info.current_output = process_output_lines(output); } drop(tasks); - // Throttle refreshes to avoid overwhelming the display (max 1 per second) + if !self.should_throttle_refresh().await { + self.refresh_display().await; + } + } + + async fn should_throttle_refresh(&self) -> bool { let now = Instant::now(); let mut last_refresh = self.last_refresh.write().await; - if now.duration_since(*last_refresh) > Duration::from_millis(1000) { + + if now.duration_since(*last_refresh) > Duration::from_millis(THROTTLE_INTERVAL_MS) { *last_refresh = now; - drop(last_refresh); - self.refresh_display().await; + false + } else { + true } } - pub async fn refresh_display(&self) { - let tasks = self.tasks.read().await; - let mut display = String::new(); - - let mut initial_shown = self.initial_display_shown.write().await; + fn render_header(&self, display: &mut String, initial_shown: &mut bool) { if !*initial_shown { - // Clear screen and show header only on first display - display.push_str("\x1b[2J\x1b[H"); + display.push_str(CLEAR_SCREEN); display.push_str("🎯 Task Execution Dashboard\n"); display.push_str("═══════════════════════════\n\n"); *initial_shown = true; } else { - // Move cursor to beginning of progress line (line 4) - display.push_str("\x1b[4;1H"); + display.push_str(MOVE_TO_PROGRESS_LINE); } - drop(initial_shown); + } - // Summary stats (this line gets updated in-place) - let (total, pending, running, completed, failed) = count_by_status(&tasks); - display.push_str(&format!("📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", - total, pending, running, completed, failed)); + fn render_progress_line(&self, display: &mut String, tasks: &HashMap) { + let (total, pending, running, completed, failed) = count_by_status(tasks); + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + total, pending, running, completed, failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); + } - // Clear to end of line and add newlines - display.push_str("\x1b[K\n\n"); + fn render_task(&self, display: &mut String, task_info: &TaskInfo) { + let task_display = format_task_display(task_info, Instant::now()); + display.push_str(&task_display); + } - // Task list (update in-place) - let mut task_list: Vec<_> = tasks.values().collect(); - task_list.sort_by_key(|t| &t.task.id); + async fn update_display_if_changed(&self, display: String) { + let mut last_display = self.last_display.write().await; + if *last_display != display { + print!("{}", display); + io::stdout().flush().unwrap(); + *last_display = display; + } + } - for task_info in task_list { - let status_icon = match task_info.status { - TaskStatus::Pending => "⏳", - TaskStatus::Running => "🏃", - TaskStatus::Completed => "✅", - TaskStatus::Failed => "❌", - }; - - let task_name = get_task_name(task_info); - - display.push_str(&format!( - "{} {} ({})", - status_icon, task_name, task_info.task.task_type - )); - display.push_str("\x1b[K\n"); // Clear to end of line - - if let Some(start_time) = task_info.start_time { - let duration = if let Some(end_time) = task_info.end_time { - end_time.duration_since(start_time) - } else { - Instant::now().duration_since(start_time) - }; - display.push_str(&format!(" ⏱️ {:.1}s", duration.as_secs_f64())); - display.push_str("\x1b[K\n"); // Clear to end of line - } + pub async fn refresh_display(&self) { + let tasks = self.tasks.read().await; + let mut display = String::new(); - if matches!(task_info.status, TaskStatus::Running) - && !task_info.current_output.is_empty() - { - let output_preview = truncate_with_ellipsis(&task_info.current_output, 100); - display.push_str(&format!(" 💬 {}", output_preview.replace('\n', " | "))); - display.push_str("\x1b[K\n"); // Clear to end of line - } + let mut initial_shown = self.initial_display_shown.write().await; + self.render_header(&mut display, &mut initial_shown); + drop(initial_shown); - if let Some(error) = task_info.error() { - let error_preview = truncate_with_ellipsis(error, 80); - display.push_str(&format!(" ⚠️ {}", error_preview.replace('\n', " "))); - display.push_str("\x1b[K\n"); // Clear to end of line - } + self.render_progress_line(&mut display, &tasks); - display.push_str("\x1b[K\n"); // Clear to end of line and add blank line + let mut task_list: Vec<_> = tasks.values().collect(); + task_list.sort_by_key(|t| &t.task.id); + + for task_info in task_list { + self.render_task(&mut display, task_info); } - // Clear any remaining lines below - display.push_str("\x1b[J"); + display.push_str(CLEAR_BELOW); - // Only update display if it changed - let mut last_display = self.last_display.write().await; - if *last_display != display { - print!("{}", display); - io::stdout().flush().unwrap(); - *last_display = display; - } + self.update_display_if_changed(display).await; } pub async fn show_final_summary(&self) { let tasks = self.tasks.read().await; - println!("\n🎉 Execution Complete!"); + println!("Execution Complete!"); println!("═══════════════════════"); let (total, _, _, completed, failed) = count_by_status(&tasks); - println!("📊 Final Results:"); - println!(" Total Tasks: {}", total); - println!(" ✅ Completed: {}", completed); - println!(" ❌ Failed: {}", failed); + println!("Total Tasks: {}", total); + println!("✅ Completed: {}", completed); + println!("❌ Failed: {}", failed); println!( - " 📈 Success Rate: {:.1}%", + "📈 Success Rate: {:.1}%", (completed as f64 / total as f64) * 100.0 ); @@ -205,5 +179,8 @@ impl TaskDashboard { } } } + + println!("\n📝 Generating summary..."); + sleep(Duration::from_millis(500)).await; } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index ebd12e7f11e3..31f20273ec9e 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,7 +1,14 @@ use std::collections::HashMap; +use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; +// Constants for display formatting +const MAX_OUTPUT_LINES: usize = 2; +const OUTPUT_PREVIEW_LENGTH: usize = 100; +const ERROR_PREVIEW_LENGTH: usize = 80; +const CLEAR_TO_EOL: &str = "\x1b[K"; + pub fn get_task_name(task_info: &TaskInfo) -> &str { if task_info.task.task_type == "sub_recipe" { task_info @@ -61,5 +68,115 @@ pub fn strip_ansi_codes(text: &str) -> String { result } +// Pure utility functions for dashboard rendering + +/// Get status icon for a given task status +pub fn get_status_icon(status: &TaskStatus) -> &'static str { + match status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + } +} + +/// Process output lines, keeping only recent lines and stripping ANSI codes +pub fn process_output_lines(output: &str) -> String { + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > MAX_OUTPUT_LINES { + &lines[lines.len() - MAX_OUTPUT_LINES..] + } else { + &lines + }; + + let clean_output = recent_lines.join("\n"); + strip_ansi_codes(&clean_output) +} + +/// Format task timing information +pub fn format_task_timing(task_info: &TaskInfo, current_time: Instant) -> Option { + task_info.start_time.map(|start_time| { + let duration = if let Some(end_time) = task_info.end_time { + end_time.duration_since(start_time) + } else { + current_time.duration_since(start_time) + }; + format!( + " ⏱️ {:.1}s{} +", + duration.as_secs_f64(), + CLEAR_TO_EOL + ) + }) +} + +/// Format task output preview +pub fn format_task_output(task_info: &TaskInfo) -> Option { + if matches!(task_info.status, TaskStatus::Running) && !task_info.current_output.is_empty() { + let output_preview = + truncate_with_ellipsis(&task_info.current_output, OUTPUT_PREVIEW_LENGTH); + Some(format!( + " 💬 {}{} +", + output_preview.replace('\n', " | "), + CLEAR_TO_EOL + )) + } else { + None + } +} + +/// Format task error information +pub fn format_task_error(task_info: &TaskInfo) -> Option { + task_info.error().map(|error| { + let error_preview = truncate_with_ellipsis(error, ERROR_PREVIEW_LENGTH); + format!( + " ⚠️ {}{} +", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + ) + }) +} + +/// Format complete task display +pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> String { + let mut display = String::new(); + + let status_icon = get_status_icon(&task_info.status); + let task_name = get_task_name(task_info); + + // Task status line + display.push_str(&format!( + "{} {} ({}){} +", + status_icon, task_name, task_info.task.task_type, CLEAR_TO_EOL + )); + + // Task timing + if let Some(timing) = format_task_timing(task_info, current_time) { + display.push_str(&timing); + } + + // Task output (if running) + if let Some(output) = format_task_output(task_info) { + display.push_str(&output); + } + + // Task error (if failed) + if let Some(error) = format_task_error(task_info) { + display.push_str(&error); + } + + // Empty line + display.push_str(&format!( + "{} +", + CLEAR_TO_EOL + )); + + display +} + #[cfg(test)] mod tests; From b84840913bbc92eb29c729fe3723b2f86450ba6b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 18:13:29 +1000 Subject: [PATCH 07/34] fixed tests --- .../sub_recipe_execution_tool/utils/tests.rs | 381 +++++++++++++++++- 1 file changed, 369 insertions(+), 12 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index a2e1d9b74a8f..b0d3e5e0dac6 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,14 +1,15 @@ -#[cfg(test)] -mod tests { - use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; - use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, get_task_name, strip_ansi_codes, truncate_with_ellipsis, - }; - use serde_json::json; - use std::collections::HashMap; - - mod truncate_with_ellipsis { - use super::*; +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::utils::{ + count_by_status, format_task_display, format_task_error, format_task_output, + format_task_timing, get_status_icon, get_task_name, process_output_lines, strip_ansi_codes, + truncate_with_ellipsis, +}; +use serde_json::json; +use std::collections::HashMap; +use tokio::time::Instant; + +mod truncate_with_ellipsis { + use super::*; #[test] fn returns_original_when_under_limit() { @@ -259,4 +260,360 @@ mod tests { ); } } -} + + mod get_status_icon { + use super::*; + + #[test] + fn returns_correct_icon_for_pending() { + assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); + } + + #[test] + fn returns_correct_icon_for_running() { + assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); + } + + #[test] + fn returns_correct_icon_for_completed() { + assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); + } + + #[test] + fn returns_correct_icon_for_failed() { + assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); + } + } + + mod process_output_lines { + use super::*; + + #[test] + fn preserves_short_output() { + let output = "line 1\nline 2"; + assert_eq!(process_output_lines(output), "line 1\nline 2"); + } + + #[test] + fn keeps_only_recent_lines_when_too_many() { + let output = "line 1\nline 2\nline 3\nline 4\nline 5"; + let result = process_output_lines(output); + assert_eq!(result, "line 4\nline 5"); + } + + #[test] + fn strips_ansi_codes_from_output() { + let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "red line 1\ngreen line 2"); + } + + #[test] + fn handles_empty_output() { + assert_eq!(process_output_lines(""), ""); + } + + #[test] + fn handles_single_line() { + assert_eq!(process_output_lines("single line"), "single line"); + } + + #[test] + fn combines_ansi_stripping_and_line_limiting() { + let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "line 3\nline 4"); + } + } + + mod format_task_timing { + use super::*; + use std::time::Duration; + + fn create_test_task_info_with_timing( + start: Option, + end: Option, + ) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Running, + start_time: start, + end_time: end, + result: None, + current_output: String::new(), + } + } + + #[test] + fn returns_none_when_no_start_time() { + let task_info = create_test_task_info_with_timing(None, None); + let current_time = Instant::now(); + assert!(format_task_timing(&task_info, current_time).is_none()); + } + + #[test] + fn formats_running_task_duration() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_millis(1500); + let task_info = create_test_task_info_with_timing(Some(start_time), None); + + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("1.5s")); + assert!(result.contains("⏱️")); + } + + #[test] + fn formats_completed_task_duration() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(2500); + let current_time = Instant::now(); // This shouldn't matter for completed tasks + let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); + + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("2.5s")); + assert!(result.contains("⏱️")); + } + } + + mod format_task_output { + use super::*; + + fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: output.to_string(), + } + } + + #[test] + fn returns_none_for_non_running_tasks() { + let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); + assert!(format_task_output(&task_info).is_none()); + + let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); + assert!(format_task_output(&task_info).is_none()); + + let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); + assert!(format_task_output(&task_info).is_none()); + } + + #[test] + fn returns_none_for_running_task_with_empty_output() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); + assert!(format_task_output(&task_info).is_none()); + } + + #[test] + fn formats_running_task_with_output() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, "Building project..."); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("💬")); + assert!(result.contains("Building project...")); + } + + #[test] + fn replaces_newlines_with_pipes() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("line 1 | line 2 | line 3")); + } + + #[test] + fn truncates_long_output() { + let long_output = "a".repeat(150); + let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_output.len() + 20); // Account for formatting + } + } + + mod format_task_error { + use super::*; + use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; + + fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { + let result = error_msg.map(|msg| TaskResult { + task_id: "test_task".to_string(), + status: TaskStatus::Failed, + data: None, + error: Some(msg.to_string()), + }); + + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Failed, + start_time: None, + end_time: None, + result, + current_output: String::new(), + } + } + + #[test] + fn returns_none_when_no_error() { + let task_info = create_test_task_info_with_error(None); + assert!(format_task_error(&task_info).is_none()); + } + + #[test] + fn formats_error_message() { + let task_info = create_test_task_info_with_error(Some("File not found")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("⚠️")); + assert!(result.contains("File not found")); + } + + #[test] + fn replaces_newlines_in_error() { + let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("Error on line 1 Error on line 2")); + } + + #[test] + fn truncates_long_error() { + let long_error = "error ".repeat(30); + let task_info = create_test_task_info_with_error(Some(&long_error)); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_error.len() + 20); // Account for formatting + } + } + + mod format_task_display { + use super::*; + use std::time::Duration; + + fn create_comprehensive_task_info( + task_name: &str, + status: TaskStatus, + start_time: Option, + end_time: Option, + current_output: &str, + error: Option<&str>, + ) -> TaskInfo { + let result = error.map(|msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { + task_id: task_name.to_string(), + status: status.clone(), + data: None, + error: Some(msg.to_string()), + }); + + TaskInfo { + task: Task { + id: task_name.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time, + end_time, + result, + current_output: current_output.to_string(), + } + } + + #[test] + fn formats_pending_task() { + let task_info = create_comprehensive_task_info( + "pending_task", + TaskStatus::Pending, + None, + None, + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("⏳")); + assert!(result.contains("pending_task")); + assert!(result.contains("(test)")); + } + + #[test] + fn formats_running_task_with_output() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_secs(2); + let task_info = create_comprehensive_task_info( + "running_task", + TaskStatus::Running, + Some(start_time), + None, + "Compiling...", + None, + ); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("🏃")); + assert!(result.contains("running_task")); + assert!(result.contains("2.0s")); + assert!(result.contains("💬")); + assert!(result.contains("Compiling...")); + } + + #[test] + fn formats_failed_task_with_error() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(1500); + let task_info = create_comprehensive_task_info( + "failed_task", + TaskStatus::Failed, + Some(start_time), + Some(end_time), + "", + Some("Compilation failed"), + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("❌")); + assert!(result.contains("failed_task")); + assert!(result.contains("1.5s")); + assert!(result.contains("⚠️")); + assert!(result.contains("Compilation failed")); + } + + #[test] + fn formats_completed_task() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_secs(3); + let task_info = create_comprehensive_task_info( + "completed_task", + TaskStatus::Completed, + Some(start_time), + Some(end_time), + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("✅")); + assert!(result.contains("completed_task")); + assert!(result.contains("3.0s")); + } + } From a02c0560dbc5f02dfd9bd0bb0e53d92d77f466ff Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 19:39:58 +1000 Subject: [PATCH 08/34] fixed scenario that only run params are specified --- crates/goose/src/agents/recipe_tools/mod.rs | 1 + .../agents/recipe_tools/param_utils/mod.rs | 116 ++ .../agents/recipe_tools/param_utils/tests.rs | 236 ++++ .../agents/recipe_tools/sub_recipe_tools.rs | 72 +- .../recipe_tools/sub_recipe_tools/tests.rs | 174 --- .../sub_recipe_execution_tool/utils/tests.rs | 1065 +++++++++-------- .../sub_recipe_execution_tool/workers.rs | 7 +- 7 files changed, 893 insertions(+), 778 deletions(-) create mode 100644 crates/goose/src/agents/recipe_tools/param_utils/mod.rs create mode 100644 crates/goose/src/agents/recipe_tools/param_utils/tests.rs diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 5f2f95fc8485..90603c88488e 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1 +1,2 @@ +pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs new file mode 100644 index 000000000000..08db06c5e25e --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -0,0 +1,116 @@ +use anyhow::Result; +use serde_json::Value; +use std::collections::HashMap; + +use crate::recipe::SubRecipe; + +pub fn extract_run_params( + sub_recipe: &SubRecipe, +) -> (HashMap, Vec>) { + let base_params = sub_recipe.values.clone().unwrap_or_default(); + + let run_params = sub_recipe + .executions + .as_ref() + .and_then(|e| e.runs.as_ref()) + .map(|runs| { + runs.iter() + .map(|run| { + let mut params = base_params.clone(); + if let Some(run_values) = &run.values { + params.extend(run_values.clone()); + } + params + }) + .collect::>() + }) + .unwrap_or_default(); + (base_params, run_params) +} + +pub fn validate_param_counts( + run_params: &[HashMap], + params_from_tool_call: &[Value], +) -> Result<()> { + if !run_params.is_empty() + && run_params.len() != params_from_tool_call.len() + && params_from_tool_call.len() > 1 + { + return Err(anyhow::anyhow!( + "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", + run_params.len(), + params_from_tool_call.len() + )); + } + Ok(()) +} + +pub fn prepare_base_params( + base_params: &HashMap, + run_params: &[HashMap], + params_from_tool_call: &[Value], +) -> Vec> { + if run_params.is_empty() { + vec![base_params.clone(); params_from_tool_call.len()] + } else { + run_params.to_vec() + } +} + +pub fn prepare_tool_params( + params_from_tool_call: &[Value], + run_params: &[HashMap], +) -> Vec { + if params_from_tool_call.len() == 1 && run_params.len() > 1 { + vec![params_from_tool_call[0].clone(); run_params.len()] + } else { + params_from_tool_call.to_vec() + } +} + +pub fn merge_parameters( + tool_params: Vec, + base_params: Vec>, +) -> Vec> { + tool_params + .into_iter() + .zip(base_params) + .map(|(tool_param, mut base_param_map)| { + if let Some(param_obj) = tool_param.as_object() { + for (key, value) in param_obj { + let value_str = value + .as_str() + .map(String::from) + .unwrap_or_else(|| value.to_string()); + base_param_map.entry(key.clone()).or_insert(value_str); + } + } + base_param_map + }) + .collect() +} + +pub fn prepare_command_params( + sub_recipe: &SubRecipe, + params_from_tool_call: Vec, +) -> Result>> { + let (base_params, run_params) = extract_run_params(sub_recipe); + + if params_from_tool_call.is_empty() { + return Ok(run_params); + } + + validate_param_counts(&run_params, ¶ms_from_tool_call)?; + + let base_params_for_merging = + prepare_base_params(&base_params, &run_params, ¶ms_from_tool_call); + let tool_params_for_merging = prepare_tool_params(¶ms_from_tool_call, &run_params); + + Ok(merge_parameters( + tool_params_for_merging, + base_params_for_merging, + )) +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs new file mode 100644 index 000000000000..81e5c412a8c7 --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -0,0 +1,236 @@ +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::recipe::{Execution, ExecutionRun, SubRecipe}; + use serde_json::json; + use serde_json::Value; + + use crate::agents::recipe_tools::param_utils::prepare_command_params; + + fn setup_default_sub_recipe() -> SubRecipe { + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: "test_sub_recipe.yaml".to_string(), + values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + executions: None, + }; + sub_recipe + } + + fn create_execution_values(key: &str, values: Vec) -> Execution { + let runs = values + .iter() + .map(|value| ExecutionRun { + values: Some(HashMap::from([(key.to_string(), value.to_string())])), + }) + .collect(); + Execution { + parallel: true, + runs: Some(runs), + } + } + + mod prepare_command_params_tests { + use super::*; + + mod without_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result.len(), 0); + } + } + + mod with_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = + Some(create_execution_values("key2", vec!["value2".to_string()])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ("key3".to_string(), "value3".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + sub_recipe.executions = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_return_command_param_when_all_from_values_in_sub_recipe() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( + ) { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + ),]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( + ) { + let parameter_array = vec![ + json!(HashMap::from([("key3".to_string(), "value3".to_string())])), + json!(HashMap::from([("key4".to_string(), "value4".to_string())])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = + Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array); + + assert!(result.is_err()); + } + } + } +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index e50ee588700f..8cd1a8599037 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::{collections::HashMap, fs}; +use std::fs; use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; @@ -8,6 +8,8 @@ use serde_json::{json, Map, Value}; use crate::agents::sub_recipe_execution_tool::lib::Task; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; +use super::param_utils::prepare_command_params; + pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { @@ -155,73 +157,5 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result { Ok(create_input_schema(param_properties, param_required)) } -fn extract_run_params( - sub_recipe: &SubRecipe, -) -> (HashMap, Vec>) { - let base_params = sub_recipe.values.clone().unwrap_or_default(); - - let run_params = sub_recipe - .executions - .as_ref() - .and_then(|e| e.runs.as_ref()) - .map(|runs| { - runs.iter() - .map(|run| { - let mut params = base_params.clone(); - if let Some(run_values) = &run.values { - params.extend(run_values.clone()); - } - params - }) - .collect::>() - }) - .unwrap_or_default(); - (base_params, run_params) -} - -fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Vec, -) -> Result>> { - let (base_params, run_params) = extract_run_params(sub_recipe); - - if params_from_tool_call.is_empty() { - return Ok(run_params); - } - - if !run_params.is_empty() && run_params.len() != params_from_tool_call.len() { - return Err(anyhow::anyhow!( - "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", - run_params.len(), - params_from_tool_call.len() - )); - } - - let run_params_for_merging = if run_params.is_empty() { - vec![base_params; params_from_tool_call.len()] - } else { - run_params - }; - - let merged_params = params_from_tool_call - .into_iter() - .zip(run_params_for_merging) - .map(|(tool_param, mut run_param_map)| { - if let Some(param_obj) = tool_param.as_object() { - for (key, value) in param_obj { - let value_str = value - .as_str() - .map(String::from) - .unwrap_or_else(|| value.to_string()); - run_param_map.entry(key.clone()).or_insert(value_str); - } - } - run_param_map - }) - .collect(); - - Ok(merged_params) -} - #[cfg(test)] mod tests; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index fac6a095bbfe..a4956f65edda 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -30,180 +30,6 @@ mod tests { } } - mod prepare_command_params_tests { - use super::*; - - use crate::agents::recipe_tools::sub_recipe_tools::{ - prepare_command_params, tests::tests::setup_default_sub_recipe, - }; - - mod without_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "value2".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_value_override_passed_param_value() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "different_value".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_empty_command_param() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!(result.len(), 0); - } - } - - mod with_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = - Some(create_execution_values("key2", vec!["value2".to_string()])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ("key3".to_string(), "value3".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_all_values_from_tool_call_parameters() { - let parameter_array = vec![ - json!(HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()) - ])), - json!(HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()) - ])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - sub_recipe.executions = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ]), - ], - result - ); - } - - #[test] - fn test_return_command_param_when_all_from_values_in_sub_recipe() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( - ) { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array); - - assert!(result.is_err()); - } - } - } - mod get_input_schema { use super::*; use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index b0d3e5e0dac6..88d823092c34 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -11,609 +11,610 @@ use tokio::time::Instant; mod truncate_with_ellipsis { use super::*; - #[test] - fn returns_original_when_under_limit() { - assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); - assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); - } + #[test] + fn returns_original_when_under_limit() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); + } - #[test] - fn truncates_when_over_limit() { - assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); - assert_eq!( - truncate_with_ellipsis("very long text here", 10), - "very lo..." - ); - } + #[test] + fn truncates_when_over_limit() { + assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); + assert_eq!( + truncate_with_ellipsis("very long text here", 10), + "very lo..." + ); + } - #[test] - fn handles_empty_string() { - assert_eq!(truncate_with_ellipsis("", 5), ""); - } + #[test] + fn handles_empty_string() { + assert_eq!(truncate_with_ellipsis("", 5), ""); + } - #[test] - fn handles_exact_limit() { - assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); - } + #[test] + fn handles_exact_limit() { + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + } - #[test] - fn handles_very_short_limit() { - assert_eq!(truncate_with_ellipsis("hello", 3), "..."); - assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is - } + #[test] + fn handles_very_short_limit() { + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is } +} - mod strip_ansi_codes { - use super::*; +mod strip_ansi_codes { + use super::*; - #[test] - fn preserves_plain_text() { - assert_eq!(strip_ansi_codes("hello world"), "hello world"); - assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); - } + #[test] + fn preserves_plain_text() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); + } - #[test] - fn removes_color_codes() { - assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); - assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); - } + #[test] + fn removes_color_codes() { + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); + } - #[test] - fn removes_complex_formatting() { - assert_eq!( - strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), - "bold green" - ); - assert_eq!( - strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), - "underline red" - ); - } + #[test] + fn removes_complex_formatting() { + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), + "underline red" + ); + } - #[test] - fn handles_multiple_sequences() { - let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; - assert_eq!(strip_ansi_codes(input), "red normal green"); - } + #[test] + fn handles_multiple_sequences() { + let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; + assert_eq!(strip_ansi_codes(input), "red normal green"); + } - #[test] - fn handles_empty_string() { - assert_eq!(strip_ansi_codes(""), ""); - } + #[test] + fn handles_empty_string() { + assert_eq!(strip_ansi_codes(""), ""); + } - #[test] - fn handles_malformed_sequences() { - // Incomplete escape sequence - our function strips the \x1b char - assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); - } + #[test] + fn handles_malformed_sequences() { + // Incomplete escape sequence - our function strips the \x1b char + assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); } +} - mod get_task_name { - use super::*; - - #[test] - fn extracts_sub_recipe_name() { - let sub_recipe_task = Task { - id: "task_1".to_string(), - task_type: "sub_recipe".to_string(), - payload: json!({ - "sub_recipe": { - "name": "my_recipe", - "recipe_path": "/path/to/recipe" - } - }), - }; - - let task_info = TaskInfo { - task: sub_recipe_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "my_recipe"); - } +mod get_task_name { + use super::*; - #[test] - fn falls_back_to_task_id_for_text_instruction() { - let text_task = Task { - id: "task_2".to_string(), - task_type: "text_instruction".to_string(), - payload: json!({"text_instruction": "do something"}), - }; - - let task_info = TaskInfo { - task: text_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "task_2"); - } + #[test] + fn extracts_sub_recipe_name() { + let sub_recipe_task = Task { + id: "task_1".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": "my_recipe", + "recipe_path": "/path/to/recipe" + } + }), + }; + + let task_info = TaskInfo { + task: sub_recipe_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "my_recipe"); + } - #[test] - fn falls_back_to_task_id_when_sub_recipe_name_missing() { - let malformed_task = Task { - id: "task_3".to_string(), - task_type: "sub_recipe".to_string(), - payload: json!({ - "sub_recipe": { - "recipe_path": "/path/to/recipe" - // missing "name" field - } - }), - }; - - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "task_3"); - } + #[test] + fn falls_back_to_task_id_for_text_instruction() { + let text_task = Task { + id: "task_2".to_string(), + task_type: "text_instruction".to_string(), + payload: json!({"text_instruction": "do something"}), + }; + + let task_info = TaskInfo { + task: text_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_2"); + } - #[test] - fn falls_back_to_task_id_when_sub_recipe_missing() { - let malformed_task = Task { - id: "task_4".to_string(), - task_type: "sub_recipe".to_string(), - payload: json!({}), // missing "sub_recipe" field - }; - - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; - - assert_eq!(get_task_name(&task_info), "task_4"); - } + #[test] + fn falls_back_to_task_id_when_sub_recipe_name_missing() { + let malformed_task = Task { + id: "task_3".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "recipe_path": "/path/to/recipe" + // missing "name" field + } + }), + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_3"); } - mod count_by_status { - use super::*; - - fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { - TaskInfo { - task: Task { - id: id.to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - } - } + #[test] + fn falls_back_to_task_id_when_sub_recipe_missing() { + let malformed_task = Task { + id: "task_4".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({}), // missing "sub_recipe" field + }; + + let task_info = TaskInfo { + task: malformed_task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }; + + assert_eq!(get_task_name(&task_info), "task_4"); + } +} - #[test] - fn counts_empty_map() { - let tasks = HashMap::new(); - let (total, pending, running, completed, failed) = count_by_status(&tasks); - assert_eq!( - (total, pending, running, completed, failed), - (0, 0, 0, 0, 0) - ); - } +mod count_by_status { + use super::*; - #[test] - fn counts_single_status() { - let mut tasks = HashMap::new(); - tasks.insert( - "task1".to_string(), - create_test_task("task1", TaskStatus::Pending), - ); - tasks.insert( - "task2".to_string(), - create_test_task("task2", TaskStatus::Pending), - ); - - let (total, pending, running, completed, failed) = count_by_status(&tasks); - assert_eq!( - (total, pending, running, completed, failed), - (2, 2, 0, 0, 0) - ); + fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { + TaskInfo { + task: Task { + id: id.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), } + } - #[test] - fn counts_mixed_statuses() { - let mut tasks = HashMap::new(); - tasks.insert( - "task1".to_string(), - create_test_task("task1", TaskStatus::Pending), - ); - tasks.insert( - "task2".to_string(), - create_test_task("task2", TaskStatus::Running), - ); - tasks.insert( - "task3".to_string(), - create_test_task("task3", TaskStatus::Completed), - ); - tasks.insert( - "task4".to_string(), - create_test_task("task4", TaskStatus::Failed), - ); - tasks.insert( - "task5".to_string(), - create_test_task("task5", TaskStatus::Completed), - ); - - let (total, pending, running, completed, failed) = count_by_status(&tasks); - assert_eq!( - (total, pending, running, completed, failed), - (5, 1, 1, 2, 1) - ); - } + #[test] + fn counts_empty_map() { + let tasks = HashMap::new(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (0, 0, 0, 0, 0) + ); } - mod get_status_icon { - use super::*; + #[test] + fn counts_single_status() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Pending), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (2, 2, 0, 0, 0) + ); + } - #[test] - fn returns_correct_icon_for_pending() { - assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); - } + #[test] + fn counts_mixed_statuses() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Running), + ); + tasks.insert( + "task3".to_string(), + create_test_task("task3", TaskStatus::Completed), + ); + tasks.insert( + "task4".to_string(), + create_test_task("task4", TaskStatus::Failed), + ); + tasks.insert( + "task5".to_string(), + create_test_task("task5", TaskStatus::Completed), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (5, 1, 1, 2, 1) + ); + } +} - #[test] - fn returns_correct_icon_for_running() { - assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); - } +mod get_status_icon { + use super::*; - #[test] - fn returns_correct_icon_for_completed() { - assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); - } + #[test] + fn returns_correct_icon_for_pending() { + assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); + } - #[test] - fn returns_correct_icon_for_failed() { - assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); - } + #[test] + fn returns_correct_icon_for_running() { + assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); } - mod process_output_lines { - use super::*; + #[test] + fn returns_correct_icon_for_completed() { + assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); + } - #[test] - fn preserves_short_output() { - let output = "line 1\nline 2"; - assert_eq!(process_output_lines(output), "line 1\nline 2"); - } + #[test] + fn returns_correct_icon_for_failed() { + assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); + } +} - #[test] - fn keeps_only_recent_lines_when_too_many() { - let output = "line 1\nline 2\nline 3\nline 4\nline 5"; - let result = process_output_lines(output); - assert_eq!(result, "line 4\nline 5"); - } +mod process_output_lines { + use super::*; - #[test] - fn strips_ansi_codes_from_output() { - let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "red line 1\ngreen line 2"); - } + #[test] + fn preserves_short_output() { + let output = "line 1\nline 2"; + assert_eq!(process_output_lines(output), "line 1\nline 2"); + } - #[test] - fn handles_empty_output() { - assert_eq!(process_output_lines(""), ""); - } + #[test] + fn keeps_only_recent_lines_when_too_many() { + let output = "line 1\nline 2\nline 3\nline 4\nline 5"; + let result = process_output_lines(output); + assert_eq!(result, "line 4\nline 5"); + } - #[test] - fn handles_single_line() { - assert_eq!(process_output_lines("single line"), "single line"); - } + #[test] + fn strips_ansi_codes_from_output() { + let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "red line 1\ngreen line 2"); + } - #[test] - fn combines_ansi_stripping_and_line_limiting() { - let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "line 3\nline 4"); - } + #[test] + fn handles_empty_output() { + assert_eq!(process_output_lines(""), ""); } - mod format_task_timing { - use super::*; - use std::time::Duration; - - fn create_test_task_info_with_timing( - start: Option, - end: Option, - ) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Running, - start_time: start, - end_time: end, - result: None, - current_output: String::new(), - } - } + #[test] + fn handles_single_line() { + assert_eq!(process_output_lines("single line"), "single line"); + } - #[test] - fn returns_none_when_no_start_time() { - let task_info = create_test_task_info_with_timing(None, None); - let current_time = Instant::now(); - assert!(format_task_timing(&task_info, current_time).is_none()); - } + #[test] + fn combines_ansi_stripping_and_line_limiting() { + let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; + let result = process_output_lines(output); + assert_eq!(result, "line 3\nline 4"); + } +} - #[test] - fn formats_running_task_duration() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_millis(1500); - let task_info = create_test_task_info_with_timing(Some(start_time), None); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("1.5s")); - assert!(result.contains("⏱️")); +mod format_task_timing { + use super::*; + use std::time::Duration; + + fn create_test_task_info_with_timing(start: Option, end: Option) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Running, + start_time: start, + end_time: end, + result: None, + current_output: String::new(), } + } - #[test] - fn formats_completed_task_duration() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(2500); - let current_time = Instant::now(); // This shouldn't matter for completed tasks - let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("2.5s")); - assert!(result.contains("⏱️")); - } + #[test] + fn returns_none_when_no_start_time() { + let task_info = create_test_task_info_with_timing(None, None); + let current_time = Instant::now(); + assert!(format_task_timing(&task_info, current_time).is_none()); } - mod format_task_output { - use super::*; - - fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: output.to_string(), - } - } + #[test] + fn formats_running_task_duration() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_millis(1500); + let task_info = create_test_task_info_with_timing(Some(start_time), None); - #[test] - fn returns_none_for_non_running_tasks() { - let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); - assert!(format_task_output(&task_info).is_none()); + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("1.5s")); + assert!(result.contains("⏱️")); + } - let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); - assert!(format_task_output(&task_info).is_none()); + #[test] + fn formats_completed_task_duration() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(2500); + let current_time = Instant::now(); // This shouldn't matter for completed tasks + let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); - let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); - assert!(format_task_output(&task_info).is_none()); - } + let result = format_task_timing(&task_info, current_time).unwrap(); + assert!(result.contains("2.5s")); + assert!(result.contains("⏱️")); + } +} - #[test] - fn returns_none_for_running_task_with_empty_output() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); - assert!(format_task_output(&task_info).is_none()); - } +mod format_task_output { + use super::*; - #[test] - fn formats_running_task_with_output() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, "Building project..."); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("💬")); - assert!(result.contains("Building project...")); + fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time: None, + end_time: None, + result: None, + current_output: output.to_string(), } + } - #[test] - fn replaces_newlines_with_pipes() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("line 1 | line 2 | line 3")); - } + #[test] + fn returns_none_for_non_running_tasks() { + let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); + assert!(format_task_output(&task_info).is_none()); - #[test] - fn truncates_long_output() { - let long_output = "a".repeat(150); - let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_output.len() + 20); // Account for formatting - } + let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); + assert!(format_task_output(&task_info).is_none()); + + let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); + assert!(format_task_output(&task_info).is_none()); } - mod format_task_error { - use super::*; - use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; + #[test] + fn returns_none_for_running_task_with_empty_output() { + let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); + assert!(format_task_output(&task_info).is_none()); + } - fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { - let result = error_msg.map(|msg| TaskResult { - task_id: "test_task".to_string(), - status: TaskStatus::Failed, - data: None, - error: Some(msg.to_string()), - }); - - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Failed, - start_time: None, - end_time: None, - result, - current_output: String::new(), - } - } + #[test] + fn formats_running_task_with_output() { + let task_info = + create_test_task_info_with_output(TaskStatus::Running, "Building project..."); + let result = format_task_output(&task_info).unwrap(); - #[test] - fn returns_none_when_no_error() { - let task_info = create_test_task_info_with_error(None); - assert!(format_task_error(&task_info).is_none()); - } + assert!(result.contains("💬")); + assert!(result.contains("Building project...")); + } - #[test] - fn formats_error_message() { - let task_info = create_test_task_info_with_error(Some("File not found")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("⚠️")); - assert!(result.contains("File not found")); - } + #[test] + fn replaces_newlines_with_pipes() { + let task_info = + create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); + let result = format_task_output(&task_info).unwrap(); - #[test] - fn replaces_newlines_in_error() { - let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("Error on line 1 Error on line 2")); - } + assert!(result.contains("line 1 | line 2 | line 3")); + } - #[test] - fn truncates_long_error() { - let long_error = "error ".repeat(30); - let task_info = create_test_task_info_with_error(Some(&long_error)); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_error.len() + 20); // Account for formatting + #[test] + fn truncates_long_output() { + let long_output = "a".repeat(150); + let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); + let result = format_task_output(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_output.len() + 20); // Account for formatting + } +} + +mod format_task_error { + use super::*; + use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; + + fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { + let result = error_msg.map(|msg| TaskResult { + task_id: "test_task".to_string(), + status: TaskStatus::Failed, + data: None, + error: Some(msg.to_string()), + }); + + TaskInfo { + task: Task { + id: "test_task".to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status: TaskStatus::Failed, + start_time: None, + end_time: None, + result, + current_output: String::new(), } } - mod format_task_display { - use super::*; - use std::time::Duration; + #[test] + fn returns_none_when_no_error() { + let task_info = create_test_task_info_with_error(None); + assert!(format_task_error(&task_info).is_none()); + } + + #[test] + fn formats_error_message() { + let task_info = create_test_task_info_with_error(Some("File not found")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("⚠️")); + assert!(result.contains("File not found")); + } + + #[test] + fn replaces_newlines_in_error() { + let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("Error on line 1 Error on line 2")); + } - fn create_comprehensive_task_info( - task_name: &str, - status: TaskStatus, - start_time: Option, - end_time: Option, - current_output: &str, - error: Option<&str>, - ) -> TaskInfo { - let result = error.map(|msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { + #[test] + fn truncates_long_error() { + let long_error = "error ".repeat(30); + let task_info = create_test_task_info_with_error(Some(&long_error)); + let result = format_task_error(&task_info).unwrap(); + + assert!(result.contains("...")); + assert!(result.len() < long_error.len() + 20); // Account for formatting + } +} + +mod format_task_display { + use super::*; + use std::time::Duration; + + fn create_comprehensive_task_info( + task_name: &str, + status: TaskStatus, + start_time: Option, + end_time: Option, + current_output: &str, + error: Option<&str>, + ) -> TaskInfo { + let result = error.map( + |msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { task_id: task_name.to_string(), status: status.clone(), data: None, error: Some(msg.to_string()), - }); - - TaskInfo { - task: Task { - id: task_name.to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time, - end_time, - result, - current_output: current_output.to_string(), - } + }, + ); + + TaskInfo { + task: Task { + id: task_name.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }, + status, + start_time, + end_time, + result, + current_output: current_output.to_string(), } + } - #[test] - fn formats_pending_task() { - let task_info = create_comprehensive_task_info( - "pending_task", - TaskStatus::Pending, - None, - None, - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("⏳")); - assert!(result.contains("pending_task")); - assert!(result.contains("(test)")); - } + #[test] + fn formats_pending_task() { + let task_info = create_comprehensive_task_info( + "pending_task", + TaskStatus::Pending, + None, + None, + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("⏳")); + assert!(result.contains("pending_task")); + assert!(result.contains("(test)")); + } - #[test] - fn formats_running_task_with_output() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_secs(2); - let task_info = create_comprehensive_task_info( - "running_task", - TaskStatus::Running, - Some(start_time), - None, - "Compiling...", - None, - ); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("🏃")); - assert!(result.contains("running_task")); - assert!(result.contains("2.0s")); - assert!(result.contains("💬")); - assert!(result.contains("Compiling...")); - } + #[test] + fn formats_running_task_with_output() { + let start_time = Instant::now(); + let current_time = start_time + Duration::from_secs(2); + let task_info = create_comprehensive_task_info( + "running_task", + TaskStatus::Running, + Some(start_time), + None, + "Compiling...", + None, + ); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("🏃")); + assert!(result.contains("running_task")); + assert!(result.contains("2.0s")); + assert!(result.contains("💬")); + assert!(result.contains("Compiling...")); + } - #[test] - fn formats_failed_task_with_error() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(1500); - let task_info = create_comprehensive_task_info( - "failed_task", - TaskStatus::Failed, - Some(start_time), - Some(end_time), - "", - Some("Compilation failed"), - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("❌")); - assert!(result.contains("failed_task")); - assert!(result.contains("1.5s")); - assert!(result.contains("⚠️")); - assert!(result.contains("Compilation failed")); - } + #[test] + fn formats_failed_task_with_error() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_millis(1500); + let task_info = create_comprehensive_task_info( + "failed_task", + TaskStatus::Failed, + Some(start_time), + Some(end_time), + "", + Some("Compilation failed"), + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("❌")); + assert!(result.contains("failed_task")); + assert!(result.contains("1.5s")); + assert!(result.contains("⚠️")); + assert!(result.contains("Compilation failed")); + } - #[test] - fn formats_completed_task() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_secs(3); - let task_info = create_comprehensive_task_info( - "completed_task", - TaskStatus::Completed, - Some(start_time), - Some(end_time), - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("✅")); - assert!(result.contains("completed_task")); - assert!(result.contains("3.0s")); - } + #[test] + fn formats_completed_task() { + let start_time = Instant::now(); + let end_time = start_time + Duration::from_secs(3); + let task_info = create_comprehensive_task_info( + "completed_task", + TaskStatus::Completed, + Some(start_time), + Some(end_time), + "", + None, + ); + let current_time = Instant::now(); + let result = format_task_display(&task_info, current_time); + + assert!(result.contains("✅")); + assert!(result.contains("completed_task")); + assert!(result.contains("3.0s")); } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index d9f750480432..73ea8b5ab372 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -5,6 +5,10 @@ use std::sync::Arc; #[cfg(test)] mod tests { + use std::sync::atomic::AtomicUsize; + + use tokio::sync::mpsc; + use super::*; use crate::agents::sub_recipe_execution_tool::types::Task; @@ -18,8 +22,6 @@ mod tests { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), dashboard: None, }); @@ -30,7 +32,6 @@ mod tests { assert!(!handle.is_finished()); // Signal stop and close the channel to let the worker exit - shared_state.should_stop.store(true, Ordering::SeqCst); drop(task_tx); // Close the channel // Wait for the worker to finish From f7832da99126d25f8f9c2dee0bc4545aaf10059d Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 8 Jul 2025 21:35:10 +1000 Subject: [PATCH 09/34] display param values in the task list --- .../agents/sub_recipe_execution_tool/tasks.rs | 34 ++++++------ .../agents/sub_recipe_execution_tool/types.rs | 40 +++++++++++++- .../sub_recipe_execution_tool/utils/mod.rs | 54 ++++++++++++------- 3 files changed, 93 insertions(+), 35 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 9543f0cec504..ceac281937e2 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -58,29 +58,31 @@ async fn execute_task(task: Task, dashboard: Option>) -> Resu fn build_command(task: &Task) -> Result<(Command, String), String> { let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { - let sub_recipe = task.payload.get("sub_recipe").unwrap(); - let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap(); - let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap(); - let command_parameters = sub_recipe.get("command_parameters").unwrap(); + let sub_recipe_name = task + .get_sub_recipe_name() + .ok_or("Missing sub_recipe name")?; + let path = task + .get_sub_recipe_path() + .ok_or("Missing sub_recipe path")?; + let command_parameters = task + .get_command_parameters() + .ok_or("Missing command_parameters")?; + output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); cmd.arg("run").arg("--recipe").arg(path); - if let Some(params_map) = command_parameters.as_object() { - for (key, value) in params_map { - let key_str = key.to_string(); - let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); - cmd.arg("--params") - .arg(format!("{}={}", key_str, value_str)); - } + + for (key, value) in command_parameters { + let key_str = key.to_string(); + let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); + cmd.arg("--params") + .arg(format!("{}={}", key_str, value_str)); } cmd } else { let text = task - .payload - .get("text_instruction") - .unwrap() - .as_str() - .unwrap(); + .get_text_instruction() + .ok_or("Missing text_instruction")?; let mut cmd = Command::new("goose"); cmd.arg("run").arg("--text").arg(text); cmd diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 2f40490a564c..ed53c7ea08e3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Map, Value}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; @@ -13,6 +13,44 @@ pub struct Task { pub payload: Value, } +impl Task { + pub fn get_sub_recipe(&self) -> Option<&Map> { + if self.task_type == "sub_recipe" { + self.payload.get("sub_recipe").and_then(|sr| sr.as_object()) + } else { + None + } + } + + pub fn get_command_parameters(&self) -> Option<&Map> { + self.get_sub_recipe() + .and_then(|sr| sr.get("command_parameters")) + .and_then(|cp| cp.as_object()) + } + + pub fn get_sub_recipe_name(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("name")) + .and_then(|name| name.as_str()) + } + + pub fn get_sub_recipe_path(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("recipe_path")) + .and_then(|path| path.as_str()) + } + + pub fn get_text_instruction(&self) -> Option<&str> { + if self.task_type != "sub_recipe" { + self.payload + .get("text_instruction") + .and_then(|text| text.as_str()) + } else { + None + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TaskResult { pub task_id: String, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index 31f20273ec9e..114febb71ff3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,3 +1,4 @@ +use serde_json::{Map, Value}; use std::collections::HashMap; use tokio::time::Instant; @@ -10,17 +11,14 @@ const ERROR_PREVIEW_LENGTH: usize = 80; const CLEAR_TO_EOL: &str = "\x1b[K"; pub fn get_task_name(task_info: &TaskInfo) -> &str { - if task_info.task.task_type == "sub_recipe" { - task_info - .task - .payload - .get("sub_recipe") - .and_then(|sr| sr.get("name")) - .and_then(|n| n.as_str()) - .unwrap_or(&task_info.task.id) - } else { - &task_info.task.id - } + task_info + .task + .get_sub_recipe_name() + .unwrap_or(&task_info.task.id) +} + +pub fn get_command_parameters(task_info: &TaskInfo) -> Option<&Map> { + task_info.task.get_command_parameters() } pub fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { @@ -68,9 +66,7 @@ pub fn strip_ansi_codes(text: &str) -> String { result } -// Pure utility functions for dashboard rendering -/// Get status icon for a given task status pub fn get_status_icon(status: &TaskStatus) -> &'static str { match status { TaskStatus::Pending => "⏳", @@ -80,7 +76,6 @@ pub fn get_status_icon(status: &TaskStatus) -> &'static str { } } -/// Process output lines, keeping only recent lines and stripping ANSI codes pub fn process_output_lines(output: &str) -> String { let lines: Vec<&str> = output.lines().collect(); let recent_lines = if lines.len() > MAX_OUTPUT_LINES { @@ -139,7 +134,29 @@ pub fn format_task_error(task_info: &TaskInfo) -> Option { }) } -/// Format complete task display +pub fn format_command_parameters(task_info: &TaskInfo) -> Option { + get_command_parameters(task_info).map(|params| { + if params.is_empty() { + return format!(" 📋 Parameters: (none){}\n", CLEAR_TO_EOL); + } + + let params_str = params + .iter() + .map(|(key, value)| { + let value_str = match value { + Value::String(s) => s.clone(), + _ => value.to_string(), + }; + format!("{}={}", key, value_str) + }) + .collect::>() + .join(", "); + + let params_preview = truncate_with_ellipsis(¶ms_str, OUTPUT_PREVIEW_LENGTH); + format!(" 📋 Parameters: {}{}\n", params_preview, CLEAR_TO_EOL) + }) +} + pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> String { let mut display = String::new(); @@ -153,12 +170,14 @@ pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> Strin status_icon, task_name, task_info.task.task_type, CLEAR_TO_EOL )); - // Task timing + if let Some(params) = format_command_parameters(task_info) { + display.push_str(¶ms); + } + if let Some(timing) = format_task_timing(task_info, current_time) { display.push_str(&timing); } - // Task output (if running) if let Some(output) = format_task_output(task_info) { display.push_str(&output); } @@ -168,7 +187,6 @@ pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> Strin display.push_str(&error); } - // Empty line display.push_str(&format!( "{} ", From e601ca2a28dc3201964e90bfbb5e475e5857d3eb Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 02:06:50 +1000 Subject: [PATCH 10/34] use stream to send the output to session --- crates/goose-cli/src/session/mod.rs | 21 ++- .../sub_recipe_execution_tool/dashboard.rs | 146 ++++++++++++----- .../sub_recipe_execution_tool/executor.rs | 155 ++++++++---------- .../agents/sub_recipe_execution_tool/lib.rs | 18 +- .../sub_recipe_execute_task_tool.rs | 40 +++-- .../agents/sub_recipe_execution_tool/tasks.rs | 24 ++- .../agents/sub_recipe_execution_tool/types.rs | 2 +- .../sub_recipe_execution_tool/utils/mod.rs | 1 - .../sub_recipe_execution_tool/workers.rs | 61 +------ crates/goose/src/agents/sub_recipe_manager.rs | 25 --- 10 files changed, 253 insertions(+), 240 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index e978fbe12beb..6efcf16187ee 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -15,6 +15,7 @@ use goose::permission::Permission; use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; +use std::io::Write; use anyhow::{Context, Result}; use completion::GooseCompleter; @@ -1011,8 +1012,11 @@ impl Session { }; (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) } else if let Some(Value::String(output)) = o.get("output") { - // Fallback for other MCP notification types + // Shell tool notification (output.to_owned(), None, None) + } else if let Some(Value::String(display)) = o.get("display") { + // Dashboard notification - return raw display content with ANSI codes + (display.to_owned(), None, Some("dashboard".to_string())) } else { (data.to_string(), None, None) } @@ -1022,6 +1026,21 @@ impl Session { }, }; + // Handle dashboard notifications specially - print raw content with ANSI codes + if let Some(ref notification_type) = _notification_type { + if notification_type == "dashboard" { + if interactive { + let _ = progress_bars.hide(); + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } else { + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } + continue; // Skip the normal notification handling below + } + } + // Handle subagent notifications - show immediately if let Some(_id) = subagent_id { // Show subagent notifications immediately (no buffering) with compact spacing diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs index ee8ae16bd4b5..572c82446073 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs @@ -1,14 +1,21 @@ +use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; +use serde_json::json; use std::collections::HashMap; -use std::io::{self, Write}; use std::sync::Arc; -use tokio::sync::RwLock; +use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_display, get_task_name, process_output_lines, + count_by_status, format_task_display, get_task_name, }; +#[derive(Debug, Clone, PartialEq)] +pub enum DisplayMode { + Dashboard, + SingleTaskOutput, +} + const THROTTLE_INTERVAL_MS: u64 = 1000; const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; @@ -20,10 +27,16 @@ pub struct TaskDashboard { last_display: Arc>, last_refresh: Arc>, initial_display_shown: Arc>, + notifier: mpsc::Sender, + display_mode: DisplayMode, } impl TaskDashboard { - pub fn new(tasks: Vec) -> Self { + pub fn new( + tasks: Vec, + display_mode: DisplayMode, + notifier: mpsc::Sender, + ) -> Self { let task_map = tasks .into_iter() .map(|task| { @@ -47,6 +60,8 @@ impl TaskDashboard { last_display: Arc::new(RwLock::new(String::new())), last_refresh: Arc::new(RwLock::new(Instant::now())), initial_display_shown: Arc::new(RwLock::new(false)), + notifier, + display_mode, } } @@ -71,15 +86,34 @@ impl TaskDashboard { self.refresh_display().await; } - pub async fn update_task_output(&self, task_id: &str, output: &str) { - let mut tasks = self.tasks.write().await; - if let Some(task_info) = tasks.get_mut(task_id) { - task_info.current_output = process_output_lines(output); - } - drop(tasks); + pub async fn send_live_output(&self, task_id: &str, line: &str) { + match self.display_mode { + DisplayMode::SingleTaskOutput => { + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "dashboard", + "display": format!("{}\n", line) + } + })), + })); + } + DisplayMode::Dashboard => { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.current_output.push_str(line); + task_info.current_output.push('\n'); + } + drop(tasks); - if !self.should_throttle_refresh().await { - self.refresh_display().await; + if !self.should_throttle_refresh().await { + self.refresh_display().await; + } + } } } @@ -123,64 +157,98 @@ impl TaskDashboard { async fn update_display_if_changed(&self, display: String) { let mut last_display = self.last_display.write().await; if *last_display != display { - print!("{}", display); - io::stdout().flush().unwrap(); + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "dashboard", + "display": display.clone() + } + })), + })); *last_display = display; } } pub async fn refresh_display(&self) { - let tasks = self.tasks.read().await; - let mut display = String::new(); + match self.display_mode { + DisplayMode::Dashboard => { + let tasks = self.tasks.read().await; + let mut display = String::new(); - let mut initial_shown = self.initial_display_shown.write().await; - self.render_header(&mut display, &mut initial_shown); - drop(initial_shown); + let mut initial_shown = self.initial_display_shown.write().await; + self.render_header(&mut display, &mut initial_shown); + drop(initial_shown); - self.render_progress_line(&mut display, &tasks); + self.render_progress_line(&mut display, &tasks); - let mut task_list: Vec<_> = tasks.values().collect(); - task_list.sort_by_key(|t| &t.task.id); + let mut task_list: Vec<_> = tasks.values().collect(); + task_list.sort_by_key(|t| &t.task.id); - for task_info in task_list { - self.render_task(&mut display, task_info); - } + for task_info in task_list { + self.render_task(&mut display, task_info); + } - display.push_str(CLEAR_BELOW); + display.push_str(CLEAR_BELOW); - self.update_display_if_changed(display).await; + self.update_display_if_changed(display).await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + // Live output is handled via send_live_output method + } + } } pub async fn show_final_summary(&self) { let tasks = self.tasks.read().await; - println!("Execution Complete!"); - println!("═══════════════════════"); + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); let (total, _, _, completed, failed) = count_by_status(&tasks); - println!("Total Tasks: {}", total); - println!("✅ Completed: {}", completed); - println!("❌ Failed: {}", failed); - println!( - "📈 Success Rate: {:.1}%", + summary.push_str(&format!("Total Tasks: {}\n", total)); + summary.push_str(&format!("✅ Completed: {}\n", completed)); + summary.push_str(&format!("❌ Failed: {}\n", failed)); + summary.push_str(&format!( + "📈 Success Rate: {:.1}%\n", (completed as f64 / total as f64) * 100.0 - ); + )); if failed > 0 { - println!("\n❌ Failed Tasks:"); + summary.push_str("\n❌ Failed Tasks:\n"); for task_info in tasks.values() { if matches!(task_info.status, TaskStatus::Failed) { let task_name = get_task_name(task_info); - println!(" • {}", task_name); + summary.push_str(&format!(" • {}\n", task_name)); if let Some(error) = task_info.error() { - println!(" Error: {}", error); + summary.push_str(&format!(" Error: {}\n", error)); } } } } - println!("\n📝 Generating summary..."); + summary.push_str("\n📝 Generating summary...\n"); + + // Send the final summary via notification + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "dashboard", + "display": summary + } + })), + })); + sleep(Duration::from_millis(500)).await; } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index 9d869666e4af..3e1974323255 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -1,9 +1,10 @@ +use mcp_core::protocol::JsonRpcMessage; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::dashboard::{DisplayMode, TaskDashboard}; use crate::agents::sub_recipe_execution_tool::lib::{ Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -12,10 +13,18 @@ use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; const EXECUTION_STATUS_COMPLETED: &str = "completed"; -pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { +pub async fn execute_single_task( + task: &Task, + config: Config, + notifier: mpsc::Sender, +) -> ExecutionResponse { let start_time = Instant::now(); - let result = process_task(task, config.timeout_seconds).await; - + let dashboard = Arc::new(TaskDashboard::new( + vec![task.clone()], + DisplayMode::SingleTaskOutput, + notifier, + )); + let result = process_task(task, config.timeout_seconds, dashboard).await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -26,6 +35,62 @@ pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionRespon } } +pub async fn execute_tasks_in_parallel( + tasks: Vec, + config: Config, + notifier: mpsc::Sender, +) -> ExecutionResponse { + let dashboard = Arc::new(TaskDashboard::new( + tasks.clone(), + DisplayMode::Dashboard, + notifier, + )); + let start_time = Instant::now(); + let task_count = tasks.len(); + + if task_count == 0 { + return create_empty_response(); + } + + dashboard.refresh_display().await; + + let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); + + if let Err(e) = send_tasks_to_channel(tasks, task_tx).await { + eprintln!("Execution failed: {}", e); + return create_error_response(e); + } + + let shared_state = create_shared_state(task_rx, result_tx, dashboard.clone()); + + // Simple static worker allocation - no dynamic scaling needed + let worker_count = std::cmp::min(task_count, config.max_workers); + let mut worker_handles = Vec::new(); + for i in 0..worker_count { + let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); + worker_handles.push(handle); + } + + let results = collect_results(&mut result_rx, dashboard.clone(), task_count).await; + + for handle in worker_handles { + if let Err(e) = handle.await { + eprintln!("Worker error: {}", e); + } + } + + dashboard.show_final_summary().await; + + let execution_time = start_time.elapsed().as_millis(); + let stats = calculate_stats(&results, execution_time); + + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results, + stats, + } +} + fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats { let completed = results .iter() @@ -44,30 +109,6 @@ fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> Execution } } -struct ExecutionContext { - tasks: Vec, - config: Config, - dashboard: Arc, - start_time: Instant, -} - -impl ExecutionContext { - fn new(tasks: Vec, config: Config) -> Self { - let dashboard = Arc::new(TaskDashboard::new(tasks.clone())); - - Self { - tasks, - config, - dashboard, - start_time: Instant::now(), - } - } - - fn task_count(&self) -> usize { - self.tasks.len() - } -} - fn create_channels( task_count: usize, ) -> ( @@ -90,7 +131,7 @@ fn create_shared_state( task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - dashboard: Some(dashboard), + dashboard, }) } @@ -138,50 +179,6 @@ async fn collect_results( results } -async fn execute_with_context(ctx: ExecutionContext) -> Result { - let task_count = ctx.task_count(); - - if task_count == 0 { - return Ok(create_empty_response()); - } - - ctx.dashboard.refresh_display().await; - - let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); - - send_tasks_to_channel(ctx.tasks, task_tx).await?; - - let shared_state = create_shared_state(task_rx, result_tx, ctx.dashboard.clone()); - - // Simple static worker allocation - no dynamic scaling needed - let worker_count = std::cmp::min(task_count, ctx.config.max_workers); - let mut worker_handles = Vec::new(); - for i in 0..worker_count { - let handle = spawn_worker(shared_state.clone(), i, ctx.config.timeout_seconds); - worker_handles.push(handle); - } - - let results = collect_results(&mut result_rx, ctx.dashboard.clone(), task_count).await; - - // Wait for all workers to finish - for handle in worker_handles { - if let Err(e) = handle.await { - eprintln!("Worker error: {}", e); - } - } - - ctx.dashboard.show_final_summary().await; - - let execution_time = ctx.start_time.elapsed().as_millis(); - let stats = calculate_stats(&results, execution_time); - - Ok(ExecutionResponse { - status: EXECUTION_STATUS_COMPLETED.to_string(), - results, - stats, - }) -} - fn create_error_response(_error: String) -> ExecutionResponse { ExecutionResponse { status: "failed".to_string(), @@ -194,15 +191,3 @@ fn create_error_response(_error: String) -> ExecutionResponse { }, } } - -pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { - let ctx = ExecutionContext::new(tasks, config); - - match execute_with_context(ctx).await { - Ok(response) => response, - Err(e) => { - eprintln!("Execution failed: {}", e); - create_error_response(e) - } - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 6c4718c2b3a0..5973025a1786 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -1,12 +1,19 @@ -use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; -pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; +use crate::agents::sub_recipe_execution_tool::executor::{ + execute_single_task, execute_tasks_in_parallel, +}; pub use crate::agents::sub_recipe_execution_tool::types::{ Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; +use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; +use tokio::sync::mpsc; -pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { +pub async fn execute_tasks( + input: Value, + execution_mode: &str, + notifier: mpsc::Sender, +) -> Result { let tasks: Vec = serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) .map_err(|e| format!("Failed to parse tasks: {}", e))?; @@ -17,11 +24,12 @@ pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { if task_count == 1 { - let response = execute_single_task(&tasks[0], config).await; + let response = execute_single_task(&tasks[0], config, notifier).await; serde_json::to_value(response) .map_err(|e| format!("Failed to serialize response: {}", e)) } else { @@ -29,7 +37,7 @@ pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { - let response = parallel_execute(tasks, config).await; + let response = execute_tasks_in_parallel(tasks, config, notifier).await; serde_json::to_value(response) .map_err(|e| format!("Failed to serialize response: {}", e)) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 0e7e061fd4d1..870350a55da1 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -4,6 +4,9 @@ use serde_json::Value; use crate::agents::{ sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, }; +use mcp_core::protocol::JsonRpcMessage; +use tokio::sync::mpsc; +use tokio_stream; pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; pub fn create_sub_recipe_execute_task_tool() -> Tool { @@ -119,18 +122,31 @@ Pre-created Task Based: } pub async fn run_tasks(execute_data: Value) -> ToolCallResult { - let execute_data_clone = execute_data.clone(); - let default_execution_mode_value = Value::String("sequential".to_string()); - let execution_mode = execute_data_clone - .get("execution_mode") - .unwrap_or(&default_execution_mode_value) - .as_str() - .unwrap_or("sequential"); - match execute_tasks(execute_data, execution_mode).await { - Ok(result) => { - let output = serde_json::to_string(&result).unwrap(); - ToolCallResult::from(Ok(vec![Content::text(output)])) + let (notification_tx, notification_rx) = mpsc::channel::(100); + + let result_future = async move { + let execute_data_clone = execute_data.clone(); + let default_execution_mode_value = Value::String("sequential".to_string()); + let execution_mode = execute_data_clone + .get("execution_mode") + .unwrap_or(&default_execution_mode_value) + .as_str() + .unwrap_or("sequential"); + + match execute_tasks(execute_data, execution_mode, notification_tx).await { + Ok(result) => { + let output = serde_json::to_string(&result).unwrap(); + Ok(vec![Content::text(output)]) + } + Err(e) => Err(ToolError::ExecutionError(e.to_string())), } - Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), + }; + + // Convert receiver to stream + let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx); + + ToolCallResult { + result: Box::new(Box::pin(result_future)), + notification_stream: Some(Box::new(notification_stream)), } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index ceac281937e2..969d885cedea 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -9,19 +9,15 @@ use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; -pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { - process_task_with_dashboard(task, timeout_seconds, None).await -} - -pub async fn process_task_with_dashboard( +pub async fn process_task( task: &Task, timeout_seconds: u64, - dashboard: Option>, + dashboard: Arc, ) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - match timeout(timeout_duration, execute_task(task_clone, dashboard)).await { + match timeout(timeout_duration, get_task_result(task_clone, dashboard)).await { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, @@ -43,7 +39,7 @@ pub async fn process_task_with_dashboard( } } -async fn execute_task(task: Task, dashboard: Option>) -> Result { +async fn get_task_result(task: Task, dashboard: Arc) -> Result { let (command, output_identifier) = build_command(&task)?; let (stdout_output, stderr_output, success) = run_command(command, &output_identifier, &task.id, dashboard).await?; @@ -97,7 +93,7 @@ async fn run_command( mut command: Command, output_identifier: &str, task_id: &str, - dashboard: Option>, + dashboard: Arc, ) -> Result<(String, String, bool), String> { let mut child = command .spawn() @@ -108,7 +104,8 @@ async fn run_command( let stdout_task = spawn_output_reader(stdout, output_identifier, false, task_id, dashboard.clone()); - let stderr_task = spawn_output_reader(stderr, output_identifier, true, task_id, None); + let stderr_task = + spawn_output_reader(stderr, output_identifier, true, task_id, dashboard.clone()); let status = child .wait() @@ -126,7 +123,7 @@ fn spawn_output_reader( output_identifier: &str, is_stderr: bool, task_id: &str, - dashboard: Option>, + dashboard: Arc, ) -> tokio::task::JoinHandle { let output_identifier = output_identifier.to_string(); let task_id = task_id.to_string(); @@ -138,9 +135,8 @@ fn spawn_output_reader( buffer.push('\n'); if !is_stderr { - if let Some(dashboard) = &dashboard { - dashboard.update_task_output(&task_id, &buffer).await; - } + // Use dashboard's smart output handling based on display mode + dashboard.send_live_output(&task_id, &line).await; } else { eprintln!("[stderr for {}] {}", output_identifier, line); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index ed53c7ea08e3..f93c4f5c0c9a 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -93,7 +93,7 @@ pub struct SharedState { pub task_receiver: Arc>>, pub result_sender: mpsc::Sender, pub active_workers: Arc, - pub dashboard: Option>, + pub dashboard: Arc, } impl SharedState { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index 114febb71ff3..7b684126c383 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -66,7 +66,6 @@ pub fn strip_ansi_codes(text: &str) -> String { result } - pub fn get_status_icon(status: &TaskStatus) -> &'static str { match status { TaskStatus::Pending => "⏳", diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 73ea8b5ab372..ea11a7e604cb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,66 +1,12 @@ -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; -use crate::agents::sub_recipe_execution_tool::tasks::{process_task, process_task_with_dashboard}; -use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task, TaskResult}; +use crate::agents::sub_recipe_execution_tool::tasks::process_task; +use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task}; use std::sync::Arc; -#[cfg(test)] -mod tests { - use std::sync::atomic::AtomicUsize; - - use tokio::sync::mpsc; - - use super::*; - use crate::agents::sub_recipe_execution_tool::types::Task; - - #[tokio::test] - async fn test_spawn_worker_returns_handle() { - // Create a simple shared state for testing - let (task_tx, task_rx) = mpsc::channel::(1); - let (result_tx, _result_rx) = mpsc::channel::(1); - - let shared_state = Arc::new(SharedState { - task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), - result_sender: result_tx, - active_workers: Arc::new(AtomicUsize::new(0)), - dashboard: None, - }); - - // Test that spawn_worker returns a JoinHandle - let handle = spawn_worker(shared_state.clone(), 0, 5); - - // Verify it's a JoinHandle by checking we can abort it - assert!(!handle.is_finished()); - - // Signal stop and close the channel to let the worker exit - drop(task_tx); // Close the channel - - // Wait for the worker to finish - let result = handle.await; - assert!(result.is_ok()); - } -} - async fn receive_task(state: &SharedState) -> Option { let mut receiver = state.task_receiver.lock().await; receiver.recv().await } -async fn execute_task( - task: Task, - timeout: u64, - dashboard: Option>, -) -> TaskResult { - if let Some(dashboard) = &dashboard { - dashboard.start_task(&task.id).await; - } - - if let Some(dashboard) = dashboard { - process_task_with_dashboard(&task, timeout, Some(dashboard)).await - } else { - process_task(&task, timeout).await - } -} - pub fn spawn_worker( state: Arc, worker_id: usize, @@ -75,7 +21,8 @@ pub fn spawn_worker( async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { while let Some(task) = receive_task(&state).await { - let result = execute_task(task, timeout_seconds, state.dashboard.clone()).await; + state.dashboard.start_task(&task.id).await; + let result = process_task(&task, timeout_seconds, state.dashboard.clone()).await; if let Err(e) = state.result_sender.send(result).await { eprintln!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index d759914a9d9e..cb01c3ffe4dc 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -61,31 +61,6 @@ impl SubRecipeManager { } } - // async fn call_sub_recipe_tool( - // &self, - // tool_name: &str, - // params: Value, - // ) -> Result, ToolError> { - // let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { - // let sub_recipe_name = tool_name - // .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX) - // .and_then(|s| s.strip_prefix("_")) - // .ok_or_else(|| { - // ToolError::InvalidParameters(format!( - // "Invalid sub-recipe tool name format: {}", - // tool_name - // )) - // }) - // .unwrap(); - - // ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) - // })?; - - // let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| { - // ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) - // })?; - // Ok(vec![Content::text(output)]) - // } async fn call_sub_recipe_tool( &self, tool_name: &str, From 4d11f373fd2bc889abf776d25a8655573b07687b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 03:30:00 +1000 Subject: [PATCH 11/34] refactored the files --- crates/goose-cli/src/session/mod.rs | 39 ++-- .../src/session/task_execution_display.rs | 174 +++++++++++++++++ .../sub_recipe_execution_tool/executor.rs | 28 +-- .../agents/sub_recipe_execution_tool/mod.rs | 2 +- ...dashboard.rs => task_execution_tracker.rs} | 183 +++++++----------- .../agents/sub_recipe_execution_tool/tasks.rs | 51 +++-- .../agents/sub_recipe_execution_tool/types.rs | 4 +- .../sub_recipe_execution_tool/workers.rs | 5 +- 8 files changed, 325 insertions(+), 161 deletions(-) create mode 100644 crates/goose-cli/src/session/task_execution_display.rs rename crates/goose/src/agents/sub_recipe_execution_tool/{dashboard.rs => task_execution_tracker.rs} (50%) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 6efcf16187ee..920267351331 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -4,8 +4,11 @@ mod export; mod input; mod output; mod prompt; +mod task_execution_display; mod thinking; +use crate::session::task_execution_display::TASK_EXECUTION_NOTIFICATION_TYPE; + pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; @@ -16,6 +19,7 @@ use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; use std::io::Write; +use task_execution_display::format_task_execution_notification; use anyhow::{Context, Result}; use completion::GooseCompleter; @@ -1011,12 +1015,8 @@ impl Session { } }; (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) - } else if let Some(Value::String(output)) = o.get("output") { - // Shell tool notification - (output.to_owned(), None, None) - } else if let Some(Value::String(display)) = o.get("display") { - // Dashboard notification - return raw display content with ANSI codes - (display.to_owned(), None, Some("dashboard".to_string())) + } else if let Some(result) = format_task_execution_notification(data) { + result } else { (data.to_string(), None, None) } @@ -1026,9 +1026,17 @@ impl Session { }, }; - // Handle dashboard notifications specially - print raw content with ANSI codes - if let Some(ref notification_type) = _notification_type { - if notification_type == "dashboard" { + // Handle subagent notifications - show immediately + if let Some(_id) = subagent_id { + // Show subagent notifications immediately (no buffering) with compact spacing + if interactive { + let _ = progress_bars.hide(); + println!("{}", console::style(&formatted_message).green().dim()); + } else { + progress_bars.log(&formatted_message); + } + } else if let Some(ref notification_type) = _notification_type { + if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { if interactive { let _ = progress_bars.hide(); print!("{}", formatted_message); @@ -1037,20 +1045,9 @@ impl Session { print!("{}", formatted_message); std::io::stdout().flush().unwrap(); } - continue; // Skip the normal notification handling below } } - - // Handle subagent notifications - show immediately - if let Some(_id) = subagent_id { - // Show subagent notifications immediately (no buffering) with compact spacing - if interactive { - let _ = progress_bars.hide(); - println!("{}", console::style(&formatted_message).green().dim()); - } else { - progress_bars.log(&formatted_message); - } - } else { + else { // Non-subagent notification, display immediately with compact spacing if interactive { let _ = progress_bars.hide(); diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs new file mode 100644 index 000000000000..ba57c00dedd6 --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -0,0 +1,174 @@ +use serde_json::Value; + +const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; +const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; +const CLEAR_TO_EOL: &str = "\x1b[K"; +const CLEAR_BELOW: &str = "\x1b[J"; +pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; + +pub fn format_tasks_update(data: &Value) -> String { + let mut display = String::new(); + + // Determine if this is initial display or update + static mut INITIAL_SHOWN: bool = false; + unsafe { + if !INITIAL_SHOWN { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + INITIAL_SHOWN = true; + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); + } + } + + if let Some(stats) = data.get("stats") { + let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); + let pending = stats.get("pending").and_then(|v| v.as_u64()).unwrap_or(0); + let running = stats.get("running").and_then(|v| v.as_u64()).unwrap_or(0); + let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); + let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); + + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + total, pending, running, completed, failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); + } + + if let Some(tasks) = data.get("tasks").and_then(|t| t.as_array()) { + for task in tasks { + let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let status = task + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let task_type = task + .get("task_type") + .and_then(|v| v.as_str()) + .unwrap_or("task"); + + let status_icon = match status { + "Pending" => "⏳", + "Running" => "🏃", + "Completed" => "✅", + "Failed" => "❌", + _ => "◯", + }; + + display.push_str(&format!( + "{} {} ({}): {}\n", + status_icon, id, task_type, status + )); + + if status == "Running" { + if let Some(output) = task.get("current_output").and_then(|v| v.as_str()) { + if !output.trim().is_empty() { + let lines: Vec<&str> = output.lines().collect(); + if lines.len() > 3 { + display.push_str(" ...\n"); + for line in lines.iter().rev().take(3).rev() { + display.push_str(&format!(" {}\n", line)); + } + } else { + for line in lines { + display.push_str(&format!(" {}\n", line)); + } + } + } + } + } + + if status == "Failed" { + if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + display.push_str(&format!(" Error: {}\n", error)); + } + } + } + } + + display.push_str(CLEAR_BELOW); + display +} + +pub fn format_tasks_complete(data: &Value) -> String { + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); + + if let Some(stats) = data.get("stats") { + let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); + let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); + let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); + let success_rate = stats + .get("success_rate") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + + summary.push_str(&format!("Total Tasks: {}\n", total)); + summary.push_str(&format!("✅ Completed: {}\n", completed)); + summary.push_str(&format!("❌ Failed: {}\n", failed)); + summary.push_str(&format!("📈 Success Rate: {:.1}%\n", success_rate)); + } + + if let Some(failed_tasks) = data.get("failed_tasks").and_then(|t| t.as_array()) { + if !failed_tasks.is_empty() { + summary.push_str("\n❌ Failed Tasks:\n"); + for task in failed_tasks { + let name = task + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown"); + summary.push_str(&format!(" • {}\n", name)); + if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + summary.push_str(&format!(" Error: {}\n", error)); + } + } + } + } + + summary.push_str("\n📝 Generating summary...\n"); + summary +} + +pub fn format_task_execution_notification( + data: &Value, +) -> Option<(String, Option, Option)> { + if let Value::Object(o) = data { + if o.get("type").and_then(|t| t.as_str()) == Some(TASK_EXECUTION_NOTIFICATION_TYPE) { + return Some(match o.get("subtype").and_then(|t| t.as_str()) { + Some("line_output") => { + if let Some(Value::String(line_output)) = o.get("output") { + ( + format!("{}\n", line_output), + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } else { + (data.to_string(), None, None) + } + } + Some("tasks_update") => { + let data_value = Value::Object(o.clone()); + let formatted_display = format_tasks_update(&data_value); + ( + formatted_display, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + Some("tasks_complete") => { + let data_value = Value::Object(o.clone()); + let formatted_summary = format_tasks_complete(&data_value); + ( + formatted_summary, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + _ => (data.to_string(), None, None), + }); + } + } + None +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index 3e1974323255..b1dbecc84b03 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -4,10 +4,12 @@ use std::sync::Arc; use tokio::sync::mpsc; use tokio::time::Instant; -use crate::agents::sub_recipe_execution_tool::dashboard::{DisplayMode, TaskDashboard}; use crate::agents::sub_recipe_execution_tool::lib::{ Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ + DisplayMode, TaskExecutionTracker, +}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; @@ -19,12 +21,12 @@ pub async fn execute_single_task( notifier: mpsc::Sender, ) -> ExecutionResponse { let start_time = Instant::now(); - let dashboard = Arc::new(TaskDashboard::new( + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( vec![task.clone()], DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task(task, config.timeout_seconds, dashboard).await; + let result = process_task(task, config.timeout_seconds, task_execution_tracker).await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -40,9 +42,9 @@ pub async fn execute_tasks_in_parallel( config: Config, notifier: mpsc::Sender, ) -> ExecutionResponse { - let dashboard = Arc::new(TaskDashboard::new( + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( tasks.clone(), - DisplayMode::Dashboard, + DisplayMode::MultipleTasksOutput, notifier, )); let start_time = Instant::now(); @@ -52,7 +54,7 @@ pub async fn execute_tasks_in_parallel( return create_empty_response(); } - dashboard.refresh_display().await; + task_execution_tracker.refresh_display().await; let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); @@ -61,7 +63,7 @@ pub async fn execute_tasks_in_parallel( return create_error_response(e); } - let shared_state = create_shared_state(task_rx, result_tx, dashboard.clone()); + let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); // Simple static worker allocation - no dynamic scaling needed let worker_count = std::cmp::min(task_count, config.max_workers); @@ -71,7 +73,7 @@ pub async fn execute_tasks_in_parallel( worker_handles.push(handle); } - let results = collect_results(&mut result_rx, dashboard.clone(), task_count).await; + let results = collect_results(&mut result_rx, task_execution_tracker.clone(), task_count).await; for handle in worker_handles { if let Err(e) = handle.await { @@ -79,7 +81,7 @@ pub async fn execute_tasks_in_parallel( } } - dashboard.show_final_summary().await; + task_execution_tracker.send_tasks_complete().await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&results, execution_time); @@ -125,13 +127,13 @@ fn create_channels( fn create_shared_state( task_rx: mpsc::Receiver, result_tx: mpsc::Sender, - dashboard: Arc, + task_execution_tracker: Arc, ) -> Arc { Arc::new(SharedState { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), - dashboard, + task_execution_tracker, }) } @@ -163,12 +165,12 @@ fn create_empty_response() -> ExecutionResponse { async fn collect_results( result_rx: &mut mpsc::Receiver, - dashboard: Arc, + task_execution_tracker: Arc, expected_count: usize, ) -> Vec { let mut results = Vec::new(); while let Some(result) = result_rx.recv().await { - dashboard + task_execution_tracker .complete_task(&result.task_id, result.clone()) .await; results.push(result); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index 6f131862fefd..03568dc53197 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,7 +1,7 @@ -mod dashboard; mod executor; pub mod lib; pub mod sub_recipe_execute_task_tool; +mod task_execution_tracker; mod tasks; mod types; pub mod utils; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs similarity index 50% rename from crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 572c82446073..5844f19640af 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/dashboard.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -6,32 +6,24 @@ use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; -use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_display, get_task_name, -}; +use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; #[derive(Debug, Clone, PartialEq)] pub enum DisplayMode { - Dashboard, + MultipleTasksOutput, SingleTaskOutput, } const THROTTLE_INTERVAL_MS: u64 = 1000; -const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; -const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; -const CLEAR_TO_EOL: &str = "\x1b[K"; -const CLEAR_BELOW: &str = "\x1b[J"; -pub struct TaskDashboard { +pub struct TaskExecutionTracker { tasks: Arc>>, - last_display: Arc>, last_refresh: Arc>, - initial_display_shown: Arc>, notifier: mpsc::Sender, display_mode: DisplayMode, } -impl TaskDashboard { +impl TaskExecutionTracker { pub fn new( tasks: Vec, display_mode: DisplayMode, @@ -57,9 +49,7 @@ impl TaskDashboard { Self { tasks: Arc::new(RwLock::new(task_map)), - last_display: Arc::new(RwLock::new(String::new())), last_refresh: Arc::new(RwLock::new(Instant::now())), - initial_display_shown: Arc::new(RwLock::new(false)), notifier, display_mode, } @@ -89,6 +79,7 @@ impl TaskDashboard { pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { + // Send raw output data - let subscriber format it let _ = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { @@ -96,13 +87,15 @@ impl TaskDashboard { method: "notifications/message".to_string(), params: Some(json!({ "data": { - "type": "dashboard", - "display": format!("{}\n", line) + "type": "task_execution", + "subtype": "line_output", + "task_id": task_id, + "output": line } })), })); } - DisplayMode::Dashboard => { + DisplayMode::MultipleTasksOutput => { let mut tasks = self.tasks.write().await; if let Some(task_info) = tasks.get_mut(task_id) { task_info.current_output.push_str(line); @@ -129,72 +122,53 @@ impl TaskDashboard { } } - fn render_header(&self, display: &mut String, initial_shown: &mut bool) { - if !*initial_shown { - display.push_str(CLEAR_SCREEN); - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); - *initial_shown = true; - } else { - display.push_str(MOVE_TO_PROGRESS_LINE); - } - } - - fn render_progress_line(&self, display: &mut String, tasks: &HashMap) { - let (total, pending, running, completed, failed) = count_by_status(tasks); - display.push_str(&format!( - "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", - total, pending, running, completed, failed - )); - display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); - } - - fn render_task(&self, display: &mut String, task_info: &TaskInfo) { - let task_display = format_task_display(task_info, Instant::now()); - display.push_str(&task_display); - } + async fn send_tasks_update(&self) { + let tasks = self.tasks.read().await; + let task_list: Vec<_> = tasks.values().collect(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); - async fn update_display_if_changed(&self, display: String) { - let mut last_display = self.last_display.write().await; - if *last_display != display { - let _ = self - .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "dashboard", - "display": display.clone() - } - })), - })); - *last_display = display; - } + let _ = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "task_execution", + "subtype": "tasks_update", + "stats": { + "total": total, + "pending": pending, + "running": running, + "completed": completed, + "failed": failed + }, + "tasks": task_list.iter().map(|task_info| { + let now = Instant::now(); + json!({ + "id": task_info.task.id, + "status": task_info.status, + "duration_secs": task_info.start_time.map(|start| { + if let Some(end) = task_info.end_time { + end.duration_since(start).as_secs_f64() + } else { + now.duration_since(start).as_secs_f64() + } + }), + "current_output": task_info.current_output, + "task_type": task_info.task.task_type, + "error": task_info.error() + }) + }).collect::>() + } + })), + })); } pub async fn refresh_display(&self) { match self.display_mode { - DisplayMode::Dashboard => { - let tasks = self.tasks.read().await; - let mut display = String::new(); - - let mut initial_shown = self.initial_display_shown.write().await; - self.render_header(&mut display, &mut initial_shown); - drop(initial_shown); - - self.render_progress_line(&mut display, &tasks); - - let mut task_list: Vec<_> = tasks.values().collect(); - task_list.sort_by_key(|t| &t.task.id); - - for task_info in task_list { - self.render_task(&mut display, task_info); - } - - display.push_str(CLEAR_BELOW); - - self.update_display_if_changed(display).await; + DisplayMode::MultipleTasksOutput => { + self.send_tasks_update().await; } DisplayMode::SingleTaskOutput => { // No dashboard display needed for single task output mode @@ -203,39 +177,23 @@ impl TaskDashboard { } } - pub async fn show_final_summary(&self) { + pub async fn send_tasks_complete(&self) { let tasks = self.tasks.read().await; - - let mut summary = String::new(); - summary.push_str("Execution Complete!\n"); - summary.push_str("═══════════════════════\n"); - let (total, _, _, completed, failed) = count_by_status(&tasks); - summary.push_str(&format!("Total Tasks: {}\n", total)); - summary.push_str(&format!("✅ Completed: {}\n", completed)); - summary.push_str(&format!("❌ Failed: {}\n", failed)); - summary.push_str(&format!( - "📈 Success Rate: {:.1}%\n", - (completed as f64 / total as f64) * 100.0 - )); - - if failed > 0 { - summary.push_str("\n❌ Failed Tasks:\n"); - for task_info in tasks.values() { - if matches!(task_info.status, TaskStatus::Failed) { - let task_name = get_task_name(task_info); - summary.push_str(&format!(" • {}\n", task_name)); - if let Some(error) = task_info.error() { - summary.push_str(&format!(" Error: {}\n", error)); - } - } - } - } - - summary.push_str("\n📝 Generating summary...\n"); + // Send structured summary data only + let failed_tasks: Vec<_> = tasks + .values() + .filter(|task_info| matches!(task_info.status, TaskStatus::Failed)) + .map(|task_info| { + json!({ + "id": task_info.task.id, + "name": get_task_name(task_info), + "error": task_info.error() + }) + }) + .collect(); - // Send the final summary via notification let _ = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { @@ -243,8 +201,15 @@ impl TaskDashboard { method: "notifications/message".to_string(), params: Some(json!({ "data": { - "type": "dashboard", - "display": summary + "type": "task_execution", + "subtype": "tasks_complete", + "stats": { + "total": total, + "completed": completed, + "failed": failed, + "success_rate": (completed as f64 / total as f64) * 100.0 + }, + "failed_tasks": failed_tasks } })), })); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 969d885cedea..d59eac7cfe3b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -6,18 +6,23 @@ use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; pub async fn process_task( task: &Task, timeout_seconds: u64, - dashboard: Arc, + task_execution_tracker: Arc, ) -> TaskResult { let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_seconds); - match timeout(timeout_duration, get_task_result(task_clone, dashboard)).await { + match timeout( + timeout_duration, + get_task_result(task_clone, task_execution_tracker), + ) + .await + { Ok(Ok(data)) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, @@ -39,10 +44,18 @@ pub async fn process_task( } } -async fn get_task_result(task: Task, dashboard: Arc) -> Result { +async fn get_task_result( + task: Task, + task_execution_tracker: Arc, +) -> Result { let (command, output_identifier) = build_command(&task)?; - let (stdout_output, stderr_output, success) = - run_command(command, &output_identifier, &task.id, dashboard).await?; + let (stdout_output, stderr_output, success) = run_command( + command, + &output_identifier, + &task.id, + task_execution_tracker, + ) + .await?; if success { process_output(stdout_output) @@ -93,7 +106,7 @@ async fn run_command( mut command: Command, output_identifier: &str, task_id: &str, - dashboard: Arc, + task_execution_tracker: Arc, ) -> Result<(String, String, bool), String> { let mut child = command .spawn() @@ -102,10 +115,20 @@ async fn run_command( let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let stdout_task = - spawn_output_reader(stdout, output_identifier, false, task_id, dashboard.clone()); - let stderr_task = - spawn_output_reader(stderr, output_identifier, true, task_id, dashboard.clone()); + let stdout_task = spawn_output_reader( + stdout, + output_identifier, + false, + task_id, + task_execution_tracker.clone(), + ); + let stderr_task = spawn_output_reader( + stderr, + output_identifier, + true, + task_id, + task_execution_tracker.clone(), + ); let status = child .wait() @@ -123,7 +146,7 @@ fn spawn_output_reader( output_identifier: &str, is_stderr: bool, task_id: &str, - dashboard: Arc, + task_execution_tracker: Arc, ) -> tokio::task::JoinHandle { let output_identifier = output_identifier.to_string(); let task_id = task_id.to_string(); @@ -136,7 +159,9 @@ fn spawn_output_reader( if !is_stderr { // Use dashboard's smart output handling based on display mode - dashboard.send_live_output(&task_id, &line).await; + task_execution_tracker + .send_live_output(&task_id, &line) + .await; } else { eprintln!("[stderr for {}] {}", output_identifier, line); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index f93c4f5c0c9a..62f37028efd5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; -use crate::agents::sub_recipe_execution_tool::dashboard::TaskDashboard; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { @@ -93,7 +93,7 @@ pub struct SharedState { pub task_receiver: Arc>>, pub result_sender: mpsc::Sender, pub active_workers: Arc, - pub dashboard: Arc, + pub task_execution_tracker: Arc, } impl SharedState { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index ea11a7e604cb..9e34aadda826 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -21,8 +21,9 @@ pub fn spawn_worker( async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { while let Some(task) = receive_task(&state).await { - state.dashboard.start_task(&task.id).await; - let result = process_task(&task, timeout_seconds, state.dashboard.clone()).await; + state.task_execution_tracker.start_task(&task.id).await; + let result = + process_task(&task, timeout_seconds, state.task_execution_tracker.clone()).await; if let Err(e) = state.result_sender.send(result).await { eprintln!("Worker failed to send result: {}", e); From d8fdedd9c0e02d12def90963c1f6ad1da2b22eec Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 07:00:20 +1000 Subject: [PATCH 12/34] better console ui and run sub recipe in no-session mode --- .../src/session/task_execution_display.rs | 160 +++++++++++++----- .../task_execution_tracker.rs | 25 +++ .../agents/sub_recipe_execution_tool/tasks.rs | 2 +- 3 files changed, 142 insertions(+), 45 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index ba57c00dedd6..38542ba2c29c 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -9,7 +9,6 @@ pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; pub fn format_tasks_update(data: &Value) -> String { let mut display = String::new(); - // Determine if this is initial display or update static mut INITIAL_SHOWN: bool = false; unsafe { if !INITIAL_SHOWN { @@ -37,58 +36,131 @@ pub fn format_tasks_update(data: &Value) -> String { } if let Some(tasks) = data.get("tasks").and_then(|t| t.as_array()) { - for task in tasks { - let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); - let status = task - .get("status") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let task_type = task - .get("task_type") - .and_then(|v| v.as_str()) - .unwrap_or("task"); - - let status_icon = match status { - "Pending" => "⏳", - "Running" => "🏃", - "Completed" => "✅", - "Failed" => "❌", - _ => "◯", - }; - - display.push_str(&format!( - "{} {} ({}): {}\n", - status_icon, id, task_type, status - )); + let mut sorted_tasks: Vec<_> = tasks.iter().collect(); + sorted_tasks.sort_by_key(|task| task.get("id").and_then(|v| v.as_str()).unwrap_or("")); - if status == "Running" { - if let Some(output) = task.get("current_output").and_then(|v| v.as_str()) { - if !output.trim().is_empty() { - let lines: Vec<&str> = output.lines().collect(); - if lines.len() > 3 { - display.push_str(" ...\n"); - for line in lines.iter().rev().take(3).rev() { - display.push_str(&format!(" {}\n", line)); - } - } else { - for line in lines { - display.push_str(&format!(" {}\n", line)); - } - } - } + for task in sorted_tasks { + display.push_str(&format_task_from_json(task)); + } + } + + display.push_str(CLEAR_BELOW); + display +} + +fn format_task_from_json(task: &Value) -> String { + let mut task_display = String::new(); + + let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let status = task + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let task_type = task + .get("task_type") + .and_then(|v| v.as_str()) + .unwrap_or("task"); + let task_name = task.get("task_name").and_then(|v| v.as_str()).unwrap_or(id); + let task_metadata = task + .get("task_metadata") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let status_icon = match status { + "Pending" => "⏳", + "Running" => "🏃", + "Completed" => "✅", + "Failed" => "❌", + _ => "◯", + }; + + task_display.push_str(&format!( + "{} {} ({}){}\n", + status_icon, task_name, task_type, CLEAR_TO_EOL + )); + + if !task_metadata.is_empty() { + task_display.push_str(&format!( + " 📋 Parameters: {}{}\n", + task_metadata, CLEAR_TO_EOL + )); + } + + if let Some(duration_secs) = task.get("duration_secs").and_then(|v| v.as_f64()) { + task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); + } + + if status == "Running" { + if let Some(current_output) = task.get("current_output").and_then(|v| v.as_str()) { + if !current_output.trim().is_empty() { + let processed_output = process_output_for_display(current_output); + if !processed_output.is_empty() { + task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); } } + } + } - if status == "Failed" { - if let Some(error) = task.get("error").and_then(|v| v.as_str()) { - display.push_str(&format!(" Error: {}\n", error)); + if status == "Failed" { + if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + let error_preview = truncate_with_ellipsis(error, 80); + task_display.push_str(&format!( + " ⚠️ {}{}\n", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + )); + } + } + + task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); + task_display +} + +fn process_output_for_display(output: &str) -> String { + const MAX_OUTPUT_LINES: usize = 2; + const OUTPUT_PREVIEW_LENGTH: usize = 100; + + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > MAX_OUTPUT_LINES { + &lines[lines.len() - MAX_OUTPUT_LINES..] + } else { + &lines + }; + + let clean_output = recent_lines.join(" | "); + let stripped = strip_ansi_codes(&clean_output); + truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH) +} + +fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { + if text.len() > max_len { + format!("{}...", &text[..max_len.saturating_sub(3)]) + } else { + text.to_string() + } +} + +fn strip_ansi_codes(text: &str) -> String { + let mut result = String::new(); + let mut chars = text.chars(); + + while let Some(ch) = chars.next() { + if ch == '\x1b' { + if chars.next() == Some('[') { + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } } } + } else { + result.push(ch); } } - display.push_str(CLEAR_BELOW); - display + result } pub fn format_tasks_complete(data: &Value) -> String { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 5844f19640af..59325e18bf75 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -7,6 +7,7 @@ use tokio::time::{sleep, Duration, Instant}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; +use serde_json::Value; #[derive(Debug, Clone, PartialEq)] pub enum DisplayMode { @@ -16,6 +17,28 @@ pub enum DisplayMode { const THROTTLE_INTERVAL_MS: u64 = 1000; +fn format_task_metadata(task_info: &TaskInfo) -> String { + if let Some(params) = task_info.task.get_command_parameters() { + if params.is_empty() { + return String::new(); + } + + params + .iter() + .map(|(key, value)| { + let value_str = match value { + Value::String(s) => s.clone(), + _ => value.to_string(), + }; + format!("{}={}", key, value_str) + }) + .collect::>() + .join(",") + } else { + String::new() + } +} + pub struct TaskExecutionTracker { tasks: Arc>>, last_refresh: Arc>, @@ -157,6 +180,8 @@ impl TaskExecutionTracker { }), "current_output": task_info.current_output, "task_type": task_info.task.task_type, + "task_name": get_task_name(task_info), + "task_metadata": format_task_metadata(task_info), "error": task_info.error() }) }).collect::>() diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index d59eac7cfe3b..b975184da343 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -79,7 +79,7 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--recipe").arg(path); + cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); for (key, value) in command_parameters { let key_str = key.to_string(); From f1c38f93355a6f2a3a12f8963fa1021e51c6aff8 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 09:45:26 +1000 Subject: [PATCH 13/34] added timeout in task execution --- .../src/agents/recipe_tools/sub_recipe_tools.rs | 4 ++++ .../agents/sub_recipe_execution_tool/executor.rs | 5 ++--- .../src/agents/sub_recipe_execution_tool/lib.rs | 2 +- .../sub_recipe_execute_task_tool.rs | 11 +++++------ .../src/agents/sub_recipe_execution_tool/tasks.rs | 4 ++-- .../src/agents/sub_recipe_execution_tool/types.rs | 8 ++------ .../src/agents/sub_recipe_execution_tool/workers.rs | 13 ++++--------- crates/goose/src/recipe/mod.rs | 1 + 8 files changed, 21 insertions(+), 27 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 8cd1a8599037..116e4a0a5193 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -57,6 +57,10 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re Task { id: uuid::Uuid::new_v4().to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: sub_recipe + .executions + .as_ref() + .and_then(|e| e.timeout_in_seconds), payload, } }) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index b1dbecc84b03..4ec7443a40b4 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -17,7 +17,6 @@ const EXECUTION_STATUS_COMPLETED: &str = "completed"; pub async fn execute_single_task( task: &Task, - config: Config, notifier: mpsc::Sender, ) -> ExecutionResponse { let start_time = Instant::now(); @@ -26,7 +25,7 @@ pub async fn execute_single_task( DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task(task, config.timeout_seconds, task_execution_tracker).await; + let result = process_task(task, task_execution_tracker).await; let execution_time = start_time.elapsed().as_millis(); let stats = calculate_stats(&[result.clone()], execution_time); @@ -69,7 +68,7 @@ pub async fn execute_tasks_in_parallel( let worker_count = std::cmp::min(task_count, config.max_workers); let mut worker_handles = Vec::new(); for i in 0..worker_count { - let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); + let handle = spawn_worker(shared_state.clone(), i); worker_handles.push(handle); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 5973025a1786..6a18b24907a8 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -29,7 +29,7 @@ pub async fn execute_tasks( match execution_mode { "sequential" => { if task_count == 1 { - let response = execute_single_task(&tasks[0], config, notifier).await; + let response = execute_single_task(&tasks[0], notifier).await; serde_json::to_value(response) .map_err(|e| format!("Failed to serialize response: {}", e)) } else { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 870350a55da1..40fe117a05ba 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -62,6 +62,10 @@ Pre-created Task Based: "default": "sub_recipe", "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" }, + "timeout_in_seconds": { + "type": "number", + "description": "timeout in seconds for the task." + }, "payload": { "type": "object", "properties": { @@ -97,9 +101,6 @@ Pre-created Task Based: "config": { "type": "object", "properties": { - "timeout_seconds": { - "type": "number" - }, "max_workers": { "type": "number" }, @@ -126,11 +127,9 @@ pub async fn run_tasks(execute_data: Value) -> ToolCallResult { let result_future = async move { let execute_data_clone = execute_data.clone(); - let default_execution_mode_value = Value::String("sequential".to_string()); let execution_mode = execute_data_clone .get("execution_mode") - .unwrap_or(&default_execution_mode_value) - .as_str() + .and_then(|v| v.as_str()) .unwrap_or("sequential"); match execute_tasks(execute_data, execution_mode, notification_tx).await { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index b975184da343..f5762f973855 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -11,11 +11,11 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStat pub async fn process_task( task: &Task, - timeout_seconds: u64, task_execution_tracker: Arc, ) -> TaskResult { + let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(300); let task_clone = task.clone(); - let timeout_duration = Duration::from_secs(timeout_seconds); + let timeout_duration = Duration::from_secs(timeout_in_seconds); match timeout( timeout_duration, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 62f37028efd5..4184eb65ec36 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -10,6 +10,7 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecut pub struct Task { pub id: String, pub task_type: String, + pub timeout_in_seconds: Option, pub payload: Value, } @@ -110,8 +111,6 @@ impl SharedState { pub struct Config { #[serde(default = "default_max_workers")] pub max_workers: usize, - #[serde(default = "default_timeout")] - pub timeout_seconds: u64, #[serde(default = "default_initial_workers")] pub initial_workers: usize, } @@ -120,7 +119,6 @@ impl Default for Config { fn default() -> Self { Self { max_workers: default_max_workers(), - timeout_seconds: default_timeout(), initial_workers: default_initial_workers(), } } @@ -129,9 +127,7 @@ impl Default for Config { fn default_max_workers() -> usize { 10 } -fn default_timeout() -> u64 { - 300 -} + fn default_initial_workers() -> usize { 2 } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 9e34aadda826..f891595456b5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -7,23 +7,18 @@ async fn receive_task(state: &SharedState) -> Option { receiver.recv().await } -pub fn spawn_worker( - state: Arc, - worker_id: usize, - timeout_seconds: u64, -) -> tokio::task::JoinHandle<()> { +pub fn spawn_worker(state: Arc, worker_id: usize) -> tokio::task::JoinHandle<()> { state.increment_active_workers(); tokio::spawn(async move { - worker_loop(state, worker_id, timeout_seconds).await; + worker_loop(state, worker_id).await; }) } -async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { +async fn worker_loop(state: Arc, _worker_id: usize) { while let Some(task) = receive_task(&state).await { state.task_execution_tracker.start_task(&task.id).await; - let result = - process_task(&task, timeout_seconds, state.task_execution_tracker.clone()).await; + let result = process_task(&task, state.task_execution_tracker.clone()).await; if let Err(e) = state.result_sender.send(result).await { eprintln!("Worker failed to send result: {}", e); diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 75c08a42fa16..3e4727c4092e 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -144,6 +144,7 @@ pub struct SubRecipe { pub struct Execution { #[serde(default)] pub parallel: bool, + pub timeout_in_seconds: Option, #[serde(skip_serializing_if = "Option::is_none")] pub runs: Option>, } From 54d5224e70336feb24f545a43836874c9eb9d8aa Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 10:52:44 +1000 Subject: [PATCH 14/34] cleaned up some utils and added the error tool response --- .../goose-cli/src/recipes/extract_from_cli.rs | 1 + .../agents/recipe_tools/param_utils/tests.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 5 +- .../recipe_tools/sub_recipe_tools/tests.rs | 1 + .../agents/sub_recipe_execution_tool/lib.rs | 38 +- .../sub_recipe_execute_task_tool.rs | 1 - .../sub_recipe_execution_tool/utils/mod.rs | 167 ------- .../sub_recipe_execution_tool/utils/tests.rs | 453 +----------------- crates/goose/src/recipe/mod.rs | 2 +- 9 files changed, 47 insertions(+), 622 deletions(-) diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 84d578c0cae4..8ba8658c8cb6 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -31,6 +31,7 @@ pub fn extract_recipe_info_from_cli( let additional_sub_recipe = SubRecipe { path: recipe_file_path.to_string_lossy().to_string(), name, + timeout_in_seconds: None, values: None, executions: None, }; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index 81e5c412a8c7..e7b3f6878eb9 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -12,6 +12,7 @@ mod tests { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), + timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), executions: None, }; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 116e4a0a5193..88b100c5e67f 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -57,10 +57,7 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re Task { id: uuid::Uuid::new_v4().to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: sub_recipe - .executions - .as_ref() - .and_then(|e| e.timeout_in_seconds), + timeout_in_seconds: sub_recipe.timeout_in_seconds, payload, } }) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index a4956f65edda..ca66e97819bb 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -11,6 +11,7 @@ mod tests { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), + timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), executions: None, }; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index 6a18b24907a8..f51ceea71135 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -30,17 +30,45 @@ pub async fn execute_tasks( "sequential" => { if task_count == 1 { let response = execute_single_task(&tasks[0], notifier).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) + handle_response(response) } else { Err("Sequential execution mode requires exactly one task".to_string()) } } "parallel" => { - let response = execute_tasks_in_parallel(tasks, config, notifier).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) + let response: ExecutionResponse = execute_tasks_in_parallel(tasks, config, notifier).await; + handle_response(response) } _ => Err("Invalid execution mode".to_string()), } } + +fn handle_response(response: ExecutionResponse) -> Result { + if response.stats.failed > 0 { + let failed_tasks: Vec = response.results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .map(|r| { + let error_msg = r.error.as_ref().map(|s| s.as_str()).unwrap_or("Unknown error"); + format!("Task '{}' ({}): {}", r.task_id, get_task_description(r), error_msg) + }) + .collect(); + + let error_summary = format!( + "{}/{} tasks failed:\n{}", + response.stats.failed, + response.stats.total_tasks, + failed_tasks.join("\n") + ); + + return Err(error_summary); + } + serde_json::to_value(response) + .map_err(|e| format!("Failed to serialize response: {}", e)) +} + +fn get_task_description(result: &TaskResult) -> String { + // We'd need to reconstruct task info from the result or pass it through + // For now, just use the task_id as placeholder + format!("ID: {}", result.task_id) +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 40fe117a05ba..f1ca709cb318 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -13,7 +13,6 @@ pub fn create_sub_recipe_execute_task_tool() -> Tool { Tool::new( SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, "Only use this tool when you execute sub recipe task. - EXECUTION STRATEGY DECISION: 1. PRE-CREATED TASKS: If tasks were created by subrecipe__create_task_* tools, check the execution_mode in the response: - If execution_mode is 'parallel', use parallel execution diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index 7b684126c383..bf806a50d182 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,15 +1,8 @@ use serde_json::{Map, Value}; use std::collections::HashMap; -use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; -// Constants for display formatting -const MAX_OUTPUT_LINES: usize = 2; -const OUTPUT_PREVIEW_LENGTH: usize = 100; -const ERROR_PREVIEW_LENGTH: usize = 80; -const CLEAR_TO_EOL: &str = "\x1b[K"; - pub fn get_task_name(task_info: &TaskInfo) -> &str { task_info .task @@ -21,14 +14,6 @@ pub fn get_command_parameters(task_info: &TaskInfo) -> Option<&Map String { - if text.len() > max_len { - format!("{}...", &text[..max_len.saturating_sub(3)]) - } else { - text.to_string() - } -} - pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); let (pending, running, completed, failed) = tasks.values().fold( @@ -43,157 +28,5 @@ pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usiz (total, pending, running, completed, failed) } -pub fn strip_ansi_codes(text: &str) -> String { - let mut result = String::new(); - let mut chars = text.chars(); - - while let Some(ch) = chars.next() { - if ch == '\x1b' { - if chars.next() == Some('[') { - loop { - match chars.next() { - Some(c) if c.is_ascii_alphabetic() => break, - Some(_) => continue, - None => break, - } - } - } - } else { - result.push(ch); - } - } - - result -} - -pub fn get_status_icon(status: &TaskStatus) -> &'static str { - match status { - TaskStatus::Pending => "⏳", - TaskStatus::Running => "🏃", - TaskStatus::Completed => "✅", - TaskStatus::Failed => "❌", - } -} - -pub fn process_output_lines(output: &str) -> String { - let lines: Vec<&str> = output.lines().collect(); - let recent_lines = if lines.len() > MAX_OUTPUT_LINES { - &lines[lines.len() - MAX_OUTPUT_LINES..] - } else { - &lines - }; - - let clean_output = recent_lines.join("\n"); - strip_ansi_codes(&clean_output) -} - -/// Format task timing information -pub fn format_task_timing(task_info: &TaskInfo, current_time: Instant) -> Option { - task_info.start_time.map(|start_time| { - let duration = if let Some(end_time) = task_info.end_time { - end_time.duration_since(start_time) - } else { - current_time.duration_since(start_time) - }; - format!( - " ⏱️ {:.1}s{} -", - duration.as_secs_f64(), - CLEAR_TO_EOL - ) - }) -} - -/// Format task output preview -pub fn format_task_output(task_info: &TaskInfo) -> Option { - if matches!(task_info.status, TaskStatus::Running) && !task_info.current_output.is_empty() { - let output_preview = - truncate_with_ellipsis(&task_info.current_output, OUTPUT_PREVIEW_LENGTH); - Some(format!( - " 💬 {}{} -", - output_preview.replace('\n', " | "), - CLEAR_TO_EOL - )) - } else { - None - } -} - -/// Format task error information -pub fn format_task_error(task_info: &TaskInfo) -> Option { - task_info.error().map(|error| { - let error_preview = truncate_with_ellipsis(error, ERROR_PREVIEW_LENGTH); - format!( - " ⚠️ {}{} -", - error_preview.replace('\n', " "), - CLEAR_TO_EOL - ) - }) -} - -pub fn format_command_parameters(task_info: &TaskInfo) -> Option { - get_command_parameters(task_info).map(|params| { - if params.is_empty() { - return format!(" 📋 Parameters: (none){}\n", CLEAR_TO_EOL); - } - - let params_str = params - .iter() - .map(|(key, value)| { - let value_str = match value { - Value::String(s) => s.clone(), - _ => value.to_string(), - }; - format!("{}={}", key, value_str) - }) - .collect::>() - .join(", "); - - let params_preview = truncate_with_ellipsis(¶ms_str, OUTPUT_PREVIEW_LENGTH); - format!(" 📋 Parameters: {}{}\n", params_preview, CLEAR_TO_EOL) - }) -} - -pub fn format_task_display(task_info: &TaskInfo, current_time: Instant) -> String { - let mut display = String::new(); - - let status_icon = get_status_icon(&task_info.status); - let task_name = get_task_name(task_info); - - // Task status line - display.push_str(&format!( - "{} {} ({}){} -", - status_icon, task_name, task_info.task.task_type, CLEAR_TO_EOL - )); - - if let Some(params) = format_command_parameters(task_info) { - display.push_str(¶ms); - } - - if let Some(timing) = format_task_timing(task_info, current_time) { - display.push_str(&timing); - } - - if let Some(output) = format_task_output(task_info) { - display.push_str(&output); - } - - // Task error (if failed) - if let Some(error) = format_task_error(task_info) { - display.push_str(&error); - } - - display.push_str(&format!( - "{} -", - CLEAR_TO_EOL - )); - - display -} - #[cfg(test)] mod tests; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index 88d823092c34..f921f131781f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,101 +1,20 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_display, format_task_error, format_task_output, - format_task_timing, get_status_icon, get_task_name, process_output_lines, strip_ansi_codes, - truncate_with_ellipsis, + count_by_status, + format_task_timing, get_status_icon, get_task_name, }; use serde_json::json; use std::collections::HashMap; -use tokio::time::Instant; -mod truncate_with_ellipsis { +mod test_get_task_name { use super::*; #[test] - fn returns_original_when_under_limit() { - assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); - assert_eq!(truncate_with_ellipsis("hi", 5), "hi"); - } - - #[test] - fn truncates_when_over_limit() { - assert_eq!(truncate_with_ellipsis("hello world", 5), "he..."); - assert_eq!( - truncate_with_ellipsis("very long text here", 10), - "very lo..." - ); - } - - #[test] - fn handles_empty_string() { - assert_eq!(truncate_with_ellipsis("", 5), ""); - } - - #[test] - fn handles_exact_limit() { - assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); - } - - #[test] - fn handles_very_short_limit() { - assert_eq!(truncate_with_ellipsis("hello", 3), "..."); - assert_eq!(truncate_with_ellipsis("hi", 2), "hi"); // Under limit, return as-is - } -} - -mod strip_ansi_codes { - use super::*; - - #[test] - fn preserves_plain_text() { - assert_eq!(strip_ansi_codes("hello world"), "hello world"); - assert_eq!(strip_ansi_codes("no ansi codes"), "no ansi codes"); - } - - #[test] - fn removes_color_codes() { - assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); - assert_eq!(strip_ansi_codes("\x1b[32mgreen\x1b[0m"), "green"); - } - - #[test] - fn removes_complex_formatting() { - assert_eq!( - strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), - "bold green" - ); - assert_eq!( - strip_ansi_codes("\x1b[4;31munderline red\x1b[0m"), - "underline red" - ); - } - - #[test] - fn handles_multiple_sequences() { - let input = "\x1b[31mred\x1b[0m normal \x1b[32mgreen\x1b[0m"; - assert_eq!(strip_ansi_codes(input), "red normal green"); - } - - #[test] - fn handles_empty_string() { - assert_eq!(strip_ansi_codes(""), ""); - } - - #[test] - fn handles_malformed_sequences() { - // Incomplete escape sequence - our function strips the \x1b char - assert_eq!(strip_ansi_codes("\x1b hello"), "hello"); - } -} - -mod get_task_name { - use super::*; - - #[test] - fn extracts_sub_recipe_name() { + fn test_extracts_sub_recipe_name() { let sub_recipe_task = Task { id: "task_1".to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: None, payload: json!({ "sub_recipe": { "name": "my_recipe", @@ -120,6 +39,7 @@ mod get_task_name { fn falls_back_to_task_id_for_text_instruction() { let text_task = Task { id: "task_2".to_string(), + timeout_in_seconds: None, task_type: "text_instruction".to_string(), payload: json!({"text_instruction": "do something"}), }; @@ -141,6 +61,7 @@ mod get_task_name { let malformed_task = Task { id: "task_3".to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: None, payload: json!({ "sub_recipe": { "recipe_path": "/path/to/recipe" @@ -166,6 +87,7 @@ mod get_task_name { let malformed_task = Task { id: "task_4".to_string(), task_type: "sub_recipe".to_string(), + timeout_in_seconds: None, payload: json!({}), // missing "sub_recipe" field }; @@ -190,6 +112,7 @@ mod count_by_status { task: Task { id: id.to_string(), task_type: "test".to_string(), + timeout_in_seconds: None, payload: json!({}), }, status, @@ -260,361 +183,3 @@ mod count_by_status { ); } } - -mod get_status_icon { - use super::*; - - #[test] - fn returns_correct_icon_for_pending() { - assert_eq!(get_status_icon(&TaskStatus::Pending), "⏳"); - } - - #[test] - fn returns_correct_icon_for_running() { - assert_eq!(get_status_icon(&TaskStatus::Running), "🏃"); - } - - #[test] - fn returns_correct_icon_for_completed() { - assert_eq!(get_status_icon(&TaskStatus::Completed), "✅"); - } - - #[test] - fn returns_correct_icon_for_failed() { - assert_eq!(get_status_icon(&TaskStatus::Failed), "❌"); - } -} - -mod process_output_lines { - use super::*; - - #[test] - fn preserves_short_output() { - let output = "line 1\nline 2"; - assert_eq!(process_output_lines(output), "line 1\nline 2"); - } - - #[test] - fn keeps_only_recent_lines_when_too_many() { - let output = "line 1\nline 2\nline 3\nline 4\nline 5"; - let result = process_output_lines(output); - assert_eq!(result, "line 4\nline 5"); - } - - #[test] - fn strips_ansi_codes_from_output() { - let output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "red line 1\ngreen line 2"); - } - - #[test] - fn handles_empty_output() { - assert_eq!(process_output_lines(""), ""); - } - - #[test] - fn handles_single_line() { - assert_eq!(process_output_lines("single line"), "single line"); - } - - #[test] - fn combines_ansi_stripping_and_line_limiting() { - let output = "\x1b[31mline 1\x1b[0m\n\x1b[32mline 2\x1b[0m\n\x1b[33mline 3\x1b[0m\n\x1b[34mline 4\x1b[0m"; - let result = process_output_lines(output); - assert_eq!(result, "line 3\nline 4"); - } -} - -mod format_task_timing { - use super::*; - use std::time::Duration; - - fn create_test_task_info_with_timing(start: Option, end: Option) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Running, - start_time: start, - end_time: end, - result: None, - current_output: String::new(), - } - } - - #[test] - fn returns_none_when_no_start_time() { - let task_info = create_test_task_info_with_timing(None, None); - let current_time = Instant::now(); - assert!(format_task_timing(&task_info, current_time).is_none()); - } - - #[test] - fn formats_running_task_duration() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_millis(1500); - let task_info = create_test_task_info_with_timing(Some(start_time), None); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("1.5s")); - assert!(result.contains("⏱️")); - } - - #[test] - fn formats_completed_task_duration() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(2500); - let current_time = Instant::now(); // This shouldn't matter for completed tasks - let task_info = create_test_task_info_with_timing(Some(start_time), Some(end_time)); - - let result = format_task_timing(&task_info, current_time).unwrap(); - assert!(result.contains("2.5s")); - assert!(result.contains("⏱️")); - } -} - -mod format_task_output { - use super::*; - - fn create_test_task_info_with_output(status: TaskStatus, output: &str) -> TaskInfo { - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: output.to_string(), - } - } - - #[test] - fn returns_none_for_non_running_tasks() { - let task_info = create_test_task_info_with_output(TaskStatus::Pending, "some output"); - assert!(format_task_output(&task_info).is_none()); - - let task_info = create_test_task_info_with_output(TaskStatus::Completed, "some output"); - assert!(format_task_output(&task_info).is_none()); - - let task_info = create_test_task_info_with_output(TaskStatus::Failed, "some output"); - assert!(format_task_output(&task_info).is_none()); - } - - #[test] - fn returns_none_for_running_task_with_empty_output() { - let task_info = create_test_task_info_with_output(TaskStatus::Running, ""); - assert!(format_task_output(&task_info).is_none()); - } - - #[test] - fn formats_running_task_with_output() { - let task_info = - create_test_task_info_with_output(TaskStatus::Running, "Building project..."); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("💬")); - assert!(result.contains("Building project...")); - } - - #[test] - fn replaces_newlines_with_pipes() { - let task_info = - create_test_task_info_with_output(TaskStatus::Running, "line 1\nline 2\nline 3"); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("line 1 | line 2 | line 3")); - } - - #[test] - fn truncates_long_output() { - let long_output = "a".repeat(150); - let task_info = create_test_task_info_with_output(TaskStatus::Running, &long_output); - let result = format_task_output(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_output.len() + 20); // Account for formatting - } -} - -mod format_task_error { - use super::*; - use crate::agents::sub_recipe_execution_tool::types::{TaskResult, TaskStatus}; - - fn create_test_task_info_with_error(error_msg: Option<&str>) -> TaskInfo { - let result = error_msg.map(|msg| TaskResult { - task_id: "test_task".to_string(), - status: TaskStatus::Failed, - data: None, - error: Some(msg.to_string()), - }); - - TaskInfo { - task: Task { - id: "test_task".to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status: TaskStatus::Failed, - start_time: None, - end_time: None, - result, - current_output: String::new(), - } - } - - #[test] - fn returns_none_when_no_error() { - let task_info = create_test_task_info_with_error(None); - assert!(format_task_error(&task_info).is_none()); - } - - #[test] - fn formats_error_message() { - let task_info = create_test_task_info_with_error(Some("File not found")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("⚠️")); - assert!(result.contains("File not found")); - } - - #[test] - fn replaces_newlines_in_error() { - let task_info = create_test_task_info_with_error(Some("Error on line 1\nError on line 2")); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("Error on line 1 Error on line 2")); - } - - #[test] - fn truncates_long_error() { - let long_error = "error ".repeat(30); - let task_info = create_test_task_info_with_error(Some(&long_error)); - let result = format_task_error(&task_info).unwrap(); - - assert!(result.contains("...")); - assert!(result.len() < long_error.len() + 20); // Account for formatting - } -} - -mod format_task_display { - use super::*; - use std::time::Duration; - - fn create_comprehensive_task_info( - task_name: &str, - status: TaskStatus, - start_time: Option, - end_time: Option, - current_output: &str, - error: Option<&str>, - ) -> TaskInfo { - let result = error.map( - |msg| crate::agents::sub_recipe_execution_tool::types::TaskResult { - task_id: task_name.to_string(), - status: status.clone(), - data: None, - error: Some(msg.to_string()), - }, - ); - - TaskInfo { - task: Task { - id: task_name.to_string(), - task_type: "test".to_string(), - payload: json!({}), - }, - status, - start_time, - end_time, - result, - current_output: current_output.to_string(), - } - } - - #[test] - fn formats_pending_task() { - let task_info = create_comprehensive_task_info( - "pending_task", - TaskStatus::Pending, - None, - None, - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("⏳")); - assert!(result.contains("pending_task")); - assert!(result.contains("(test)")); - } - - #[test] - fn formats_running_task_with_output() { - let start_time = Instant::now(); - let current_time = start_time + Duration::from_secs(2); - let task_info = create_comprehensive_task_info( - "running_task", - TaskStatus::Running, - Some(start_time), - None, - "Compiling...", - None, - ); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("🏃")); - assert!(result.contains("running_task")); - assert!(result.contains("2.0s")); - assert!(result.contains("💬")); - assert!(result.contains("Compiling...")); - } - - #[test] - fn formats_failed_task_with_error() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_millis(1500); - let task_info = create_comprehensive_task_info( - "failed_task", - TaskStatus::Failed, - Some(start_time), - Some(end_time), - "", - Some("Compilation failed"), - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("❌")); - assert!(result.contains("failed_task")); - assert!(result.contains("1.5s")); - assert!(result.contains("⚠️")); - assert!(result.contains("Compilation failed")); - } - - #[test] - fn formats_completed_task() { - let start_time = Instant::now(); - let end_time = start_time + Duration::from_secs(3); - let task_info = create_comprehensive_task_info( - "completed_task", - TaskStatus::Completed, - Some(start_time), - Some(end_time), - "", - None, - ); - let current_time = Instant::now(); - let result = format_task_display(&task_info, current_time); - - assert!(result.contains("✅")); - assert!(result.contains("completed_task")); - assert!(result.contains("3.0s")); - } -} diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 3e4727c4092e..fe4e514d02a8 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -135,6 +135,7 @@ pub struct Response { pub struct SubRecipe { pub name: String, pub path: String, + pub timeout_in_seconds: Option, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -144,7 +145,6 @@ pub struct SubRecipe { pub struct Execution { #[serde(default)] pub parallel: bool, - pub timeout_in_seconds: Option, #[serde(skip_serializing_if = "Option::is_none")] pub runs: Option>, } From 47e769ec0f8121703cb704303b7e794a6e174dd5 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 13:25:19 +1000 Subject: [PATCH 15/34] return task output if task times out --- .../agents/sub_recipe_execution_tool/lib.rs | 31 ++++++++++++++----- .../task_execution_tracker.rs | 7 +++++ .../agents/sub_recipe_execution_tool/tasks.rs | 23 +++++++++----- .../sub_recipe_execution_tool/utils/tests.rs | 3 +- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index f51ceea71135..d3b1f51bb517 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -36,7 +36,8 @@ pub async fn execute_tasks( } } "parallel" => { - let response: ExecutionResponse = execute_tasks_in_parallel(tasks, config, notifier).await; + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, config, notifier).await; handle_response(response) } _ => Err("Invalid execution mode".to_string()), @@ -45,26 +46,40 @@ pub async fn execute_tasks( fn handle_response(response: ExecutionResponse) -> Result { if response.stats.failed > 0 { - let failed_tasks: Vec = response.results + let failed_tasks: Vec = response + .results .iter() .filter(|r| matches!(r.status, TaskStatus::Failed)) .map(|r| { - let error_msg = r.error.as_ref().map(|s| s.as_str()).unwrap_or("Unknown error"); - format!("Task '{}' ({}): {}", r.task_id, get_task_description(r), error_msg) + let error_msg = r.error.as_deref().unwrap_or("Unknown error"); + let partial_output = r + .data + .as_ref() + .and_then(|d| d.get("partial_output")) + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("No output captured"); + + format!( + "Task '{}' ({}): {}, \noutput: {}", + r.task_id, + get_task_description(r), + error_msg, + partial_output + ) }) .collect(); - + let error_summary = format!( "{}/{} tasks failed:\n{}", response.stats.failed, response.stats.total_tasks, failed_tasks.join("\n") ); - + return Err(error_summary); } - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) + serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e)) } fn get_task_description(result: &TaskResult) -> String { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 59325e18bf75..9d18f89ba7fa 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -99,6 +99,13 @@ impl TaskExecutionTracker { self.refresh_display().await; } + pub async fn get_current_output(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks + .get(task_id) + .map(|task_info| task_info.current_output.clone()) + } + pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index f5762f973855..428f7bc81f8a 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -17,6 +17,7 @@ pub async fn process_task( let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_in_seconds); + let task_execution_tracker_clone = task_execution_tracker.clone(); match timeout( timeout_duration, get_task_result(task_clone, task_execution_tracker), @@ -35,12 +36,21 @@ pub async fn process_task( data: None, error: Some(error), }, - Err(_) => TaskResult { - task_id: task.id.clone(), - status: TaskStatus::Failed, - data: None, - error: Some("Task timeout".to_string()), - }, + Err(_) => { + let current_output = task_execution_tracker_clone + .get_current_output(&task.id) + .await + .unwrap_or_default(); + + TaskResult { + task_id: task.id.clone(), + status: TaskStatus::Failed, + data: Some(serde_json::json!({ + "partial_output": current_output + })), + error: Some(format!("Task timed out after {}s", timeout_in_seconds)), + } + } } } @@ -158,7 +168,6 @@ fn spawn_output_reader( buffer.push('\n'); if !is_stderr { - // Use dashboard's smart output handling based on display mode task_execution_tracker .send_live_output(&task_id, &line) .await; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index f921f131781f..e05e00d3ee13 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,7 +1,6 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, - format_task_timing, get_status_icon, get_task_name, + count_by_status, format_task_timing, get_status_icon, get_task_name, }; use serde_json::json; use std::collections::HashMap; From 85d5286e3f61a25de503f27db8bd402c3fd44ab8 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 14:05:02 +1000 Subject: [PATCH 16/34] changed console output a bit --- .../src/session/task_execution_display.rs | 30 ++++++++++++++++++- .../task_execution_tracker.rs | 26 +++++++++++++--- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index 38542ba2c29c..38a168241ac9 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -101,6 +101,15 @@ fn format_task_from_json(task: &Value) -> String { } } + if status == "Completed" { + if let Some(result_data) = task.get("result_data") { + let result_preview = format_result_data_for_display(result_data); + if !result_preview.is_empty() { + task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); + } + } + } + if status == "Failed" { if let Some(error) = task.get("error").and_then(|v| v.as_str()) { let error_preview = truncate_with_ellipsis(error, 80); @@ -116,6 +125,25 @@ fn format_task_from_json(task: &Value) -> String { task_display } +fn format_result_data_for_display(result_data: &Value) -> String { + match result_data { + Value::String(s) => strip_ansi_codes(s), + Value::Object(obj) => { + // Handle specific result formats + if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) { + format!("Partial output: {}", partial_output) + } else { + // Generic object display + serde_json::to_string_pretty(obj).unwrap_or_default() + } + } + Value::Array(arr) => serde_json::to_string_pretty(arr).unwrap_or_default(), + Value::Bool(b) => b.to_string(), + Value::Number(n) => n.to_string(), + Value::Null => "null".to_string(), + } +} + fn process_output_for_display(output: &str) -> String { const MAX_OUTPUT_LINES: usize = 2; const OUTPUT_PREVIEW_LENGTH: usize = 100; @@ -127,7 +155,7 @@ fn process_output_for_display(output: &str) -> String { &lines }; - let clean_output = recent_lines.join(" | "); + let clean_output = recent_lines.join(" ... "); let stripped = strip_ansi_codes(&clean_output); truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 9d18f89ba7fa..e3769dd84760 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -15,7 +15,7 @@ pub enum DisplayMode { SingleTaskOutput, } -const THROTTLE_INTERVAL_MS: u64 = 1000; +const THROTTLE_INTERVAL_MS: u64 = 250; fn format_task_metadata(task_info: &TaskInfo) -> String { if let Some(params) = task_info.task.get_command_parameters() { @@ -85,7 +85,7 @@ impl TaskExecutionTracker { task_info.start_time = Some(Instant::now()); } drop(tasks); - self.refresh_display().await; + self.force_refresh_display().await; } pub async fn complete_task(&self, task_id: &str, result: TaskResult) { @@ -96,7 +96,7 @@ impl TaskExecutionTracker { task_info.result = Some(result); } drop(tasks); - self.refresh_display().await; + self.force_refresh_display().await; } pub async fn get_current_output(&self, task_id: &str) -> Option { @@ -189,7 +189,8 @@ impl TaskExecutionTracker { "task_type": task_info.task.task_type, "task_name": get_task_name(task_info), "task_metadata": format_task_metadata(task_info), - "error": task_info.error() + "error": task_info.error(), + "result_data": task_info.data() }) }).collect::>() } @@ -209,6 +210,23 @@ impl TaskExecutionTracker { } } + // Force refresh without throttling - used for important status changes + async fn force_refresh_display(&self) { + match self.display_mode { + DisplayMode::MultipleTasksOutput => { + // Reset throttle timer to allow immediate update + let mut last_refresh = self.last_refresh.write().await; + *last_refresh = Instant::now() - Duration::from_millis(THROTTLE_INTERVAL_MS + 1); + drop(last_refresh); + + self.send_tasks_update().await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + } + } + } + pub async fn send_tasks_complete(&self) { let tasks = self.tasks.read().await; let (total, _, _, completed, failed) = count_by_status(&tasks); From 2ccfb6809d53194335f0e297875d9dad950f2380 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 14:12:46 +1000 Subject: [PATCH 17/34] test fmt --- crates/goose/src/agents/recipe_tools/param_utils/tests.rs | 1 - .../goose/src/agents/sub_recipe_execution_tool/utils/tests.rs | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index e7b3f6878eb9..fcbab0f6e9fb 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -4,7 +4,6 @@ mod tests { use crate::recipe::{Execution, ExecutionRun, SubRecipe}; use serde_json::json; - use serde_json::Value; use crate::agents::recipe_tools::param_utils::prepare_command_params; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index e05e00d3ee13..0e0500c490ec 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,7 +1,5 @@ use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; -use crate::agents::sub_recipe_execution_tool::utils::{ - count_by_status, format_task_timing, get_status_icon, get_task_name, -}; +use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::json; use std::collections::HashMap; From 445399892a893d245f29f78b8e4987ba3c28f627 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 16:11:29 +1000 Subject: [PATCH 18/34] refactored the code --- .../src/session/task_execution_display.rs | 19 +++-- .../agents/recipe_tools/param_utils/mod.rs | 9 +-- .../agents/recipe_tools/param_utils/tests.rs | 11 ++- .../agents/recipe_tools/sub_recipe_tools.rs | 45 +++++++++--- .../sub_recipe_execution_tool/executor.rs | 7 +- .../agents/sub_recipe_execution_tool/lib.rs | 69 +++++++++++-------- .../task_execution_tracker.rs | 22 ++++-- .../agents/sub_recipe_execution_tool/tasks.rs | 45 ++++++++---- .../agents/sub_recipe_execution_tool/types.rs | 15 ++-- .../sub_recipe_execution_tool/utils/mod.rs | 3 - .../sub_recipe_execution_tool/utils/tests.rs | 67 ++++++------------ .../sub_recipe_execution_tool/workers.rs | 2 +- 12 files changed, 173 insertions(+), 141 deletions(-) diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index 38a168241ac9..547c41b09e88 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -1,4 +1,5 @@ use serde_json::Value; +use std::sync::atomic::{AtomicBool, Ordering}; const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; @@ -6,19 +7,17 @@ const CLEAR_TO_EOL: &str = "\x1b[K"; const CLEAR_BELOW: &str = "\x1b[J"; pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; +static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false); + pub fn format_tasks_update(data: &Value) -> String { let mut display = String::new(); - static mut INITIAL_SHOWN: bool = false; - unsafe { - if !INITIAL_SHOWN { - display.push_str(CLEAR_SCREEN); - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); - INITIAL_SHOWN = true; - } else { - display.push_str(MOVE_TO_PROGRESS_LINE); - } + if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); } if let Some(stats) = data.get("stats") { diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs index 08db06c5e25e..e7891c1bb5e7 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -32,10 +32,11 @@ pub fn validate_param_counts( run_params: &[HashMap], params_from_tool_call: &[Value], ) -> Result<()> { - if !run_params.is_empty() - && run_params.len() != params_from_tool_call.len() - && params_from_tool_call.len() > 1 - { + let has_run_params = !run_params.is_empty(); + let multiple_params_from_tool_call = params_from_tool_call.len() > 1; + let count_mismatch = run_params.len() != params_from_tool_call.len(); + + if has_run_params && multiple_params_from_tool_call && count_mismatch { return Err(anyhow::anyhow!( "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", run_params.len(), diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index fcbab0f6e9fb..0a52da64f428 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -1,11 +1,9 @@ -#[cfg(test)] -mod tests { - use std::collections::HashMap; +use std::collections::HashMap; - use crate::recipe::{Execution, ExecutionRun, SubRecipe}; - use serde_json::json; +use crate::recipe::{Execution, ExecutionRun, SubRecipe}; +use serde_json::json; - use crate::agents::recipe_tools::param_utils::prepare_command_params; +use crate::agents::recipe_tools::param_utils::prepare_command_params; fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { @@ -233,4 +231,3 @@ mod tests { } } } -} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 88b100c5e67f..dac0825beda2 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -11,6 +11,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; +const EXECUTION_MODE_PARALLEL: &str = "parallel"; +const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); @@ -37,14 +39,19 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { +fn extract_task_parameters(params: &Value) -> &Vec { let empty_vec = vec![]; - let task_params_array = params + params .get("task_parameters") .and_then(|v| v.as_array()) - .unwrap_or(&empty_vec); - let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; - let tasks = command_params + .unwrap_or(&empty_vec) +} + +fn create_tasks_from_params( + sub_recipe: &SubRecipe, + command_params: &[std::collections::HashMap], +) -> Vec { + command_params .iter() .map(|task_command_param| { let payload = json!({ @@ -61,16 +68,36 @@ pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Re payload, } }) - .collect::>(); + .collect() +} + +fn get_execution_mode(sub_recipe: &SubRecipe) -> &'static str { let is_parallel = sub_recipe .executions .as_ref() .map(|e| e.parallel) .unwrap_or(false); - let task_execution_payload = json!({ + + if is_parallel { + EXECUTION_MODE_PARALLEL + } else { + EXECUTION_MODE_SEQUENTIAL + } +} + +fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { + json!({ "tasks": tasks, - "execution_mode": if is_parallel { "parallel" } else { "sequential" } - }); + "execution_mode": execution_mode + }) +} + +pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { + let task_params_array = extract_task_parameters(¶ms); + let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; + let tasks = create_tasks_from_params(sub_recipe, &command_params); + let execution_mode = get_execution_mode(sub_recipe); + let task_execution_payload = create_task_execution_payload(tasks, execution_mode); let tasks_json = serde_json::to_string(&task_execution_payload) .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs index 4ec7443a40b4..674bd13ae5e9 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -58,7 +58,7 @@ pub async fn execute_tasks_in_parallel( let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); if let Err(e) = send_tasks_to_channel(tasks, task_tx).await { - eprintln!("Execution failed: {}", e); + tracing::error!("Task execution failed: {}", e); return create_error_response(e); } @@ -76,7 +76,7 @@ pub async fn execute_tasks_in_parallel( for handle in worker_handles { if let Err(e) = handle.await { - eprintln!("Worker error: {}", e); + tracing::error!("Worker error: {}", e); } } @@ -180,7 +180,8 @@ async fn collect_results( results } -fn create_error_response(_error: String) -> ExecutionResponse { +fn create_error_response(error: String) -> ExecutionResponse { + tracing::error!("Creating error response: {}", error); ExecutionResponse { status: "failed".to_string(), results: vec![], diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index d3b1f51bb517..e1e20d772969 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -44,46 +44,55 @@ pub async fn execute_tasks( } } -fn handle_response(response: ExecutionResponse) -> Result { - if response.stats.failed > 0 { - let failed_tasks: Vec = response - .results - .iter() - .filter(|r| matches!(r.status, TaskStatus::Failed)) - .map(|r| { - let error_msg = r.error.as_deref().unwrap_or("Unknown error"); - let partial_output = r - .data - .as_ref() - .and_then(|d| d.get("partial_output")) - .and_then(|v| v.as_str()) - .filter(|s| !s.trim().is_empty()) - .unwrap_or("No output captured"); +fn extract_failed_tasks(results: &[TaskResult]) -> Vec { + results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .map(|r| format_failed_task_error(r)) + .collect() +} + +fn format_failed_task_error(result: &TaskResult) -> String { + let error_msg = result.error.as_deref().unwrap_or("Unknown error"); + let partial_output = result + .data + .as_ref() + .and_then(|d| d.get("partial_output")) + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("No output captured"); + + format!( + "Task '{}' ({}): {}\nOutput: {}", + result.task_id, + get_task_description(result), + error_msg, + partial_output + ) +} - format!( - "Task '{}' ({}): {}, \noutput: {}", - r.task_id, - get_task_description(r), - error_msg, - partial_output - ) - }) - .collect(); +fn format_error_summary(failed_count: usize, total_count: usize, failed_tasks: Vec) -> String { + format!( + "{}/{} tasks failed:\n{}", + failed_count, + total_count, + failed_tasks.join("\n") + ) +} - let error_summary = format!( - "{}/{} tasks failed:\n{}", +fn handle_response(response: ExecutionResponse) -> Result { + if response.stats.failed > 0 { + let failed_tasks = extract_failed_tasks(&response.results); + let error_summary = format_error_summary( response.stats.failed, response.stats.total_tasks, - failed_tasks.join("\n") + failed_tasks, ); - return Err(error_summary); } serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e)) } fn get_task_description(result: &TaskResult) -> String { - // We'd need to reconstruct task info from the result or pass it through - // For now, just use the task_id as placeholder format!("ID: {}", result.task_id) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index e3769dd84760..96176fc77930 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -16,6 +16,7 @@ pub enum DisplayMode { } const THROTTLE_INTERVAL_MS: u64 = 250; +const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500; fn format_task_metadata(task_info: &TaskInfo) -> String { if let Some(params) = task_info.task.get_command_parameters() { @@ -110,7 +111,7 @@ impl TaskExecutionTracker { match self.display_mode { DisplayMode::SingleTaskOutput => { // Send raw output data - let subscriber format it - let _ = self + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -123,7 +124,9 @@ impl TaskExecutionTracker { "output": line } })), - })); + })) { + tracing::warn!("Failed to send live output notification: {}", e); + } } DisplayMode::MultipleTasksOutput => { let mut tasks = self.tasks.write().await; @@ -157,7 +160,7 @@ impl TaskExecutionTracker { let task_list: Vec<_> = tasks.values().collect(); let (total, pending, running, completed, failed) = count_by_status(&tasks); - let _ = self + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -195,7 +198,9 @@ impl TaskExecutionTracker { }).collect::>() } })), - })); + })) { + tracing::warn!("Failed to send tasks update notification: {}", e); + } } pub async fn refresh_display(&self) { @@ -244,7 +249,7 @@ impl TaskExecutionTracker { }) .collect(); - let _ = self + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), @@ -262,8 +267,11 @@ impl TaskExecutionTracker { "failed_tasks": failed_tasks } })), - })); + })) { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } - sleep(Duration::from_millis(500)).await; + // Brief delay to ensure completion notification is processed + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 428f7bc81f8a..43443804ddd9 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -9,11 +9,13 @@ use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; +const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; + pub async fn process_task( task: &Task, task_execution_tracker: Arc, ) -> TaskResult { - let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(300); + let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_in_seconds); @@ -75,17 +77,19 @@ async fn get_task_result( } fn build_command(task: &Task) -> Result<(Command, String), String> { + let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); + let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { let sub_recipe_name = task .get_sub_recipe_name() - .ok_or("Missing sub_recipe name")?; + .ok_or_else(|| task_error("sub_recipe name"))?; let path = task .get_sub_recipe_path() - .ok_or("Missing sub_recipe path")?; + .ok_or_else(|| task_error("sub_recipe path"))?; let command_parameters = task .get_command_parameters() - .ok_or("Missing command_parameters")?; + .ok_or_else(|| task_error("command_parameters"))?; output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); @@ -101,7 +105,7 @@ fn build_command(task: &Task) -> Result<(Command, String), String> { } else { let text = task .get_text_instruction() - .ok_or("Missing text_instruction")?; + .ok_or_else(|| task_error("text_instruction"))?; let mut cmd = Command::new("goose"); cmd.arg("run").arg("--text").arg(text); cmd @@ -172,13 +176,29 @@ fn spawn_output_reader( .send_live_output(&task_id, &line) .await; } else { - eprintln!("[stderr for {}] {}", output_identifier, line); + tracing::warn!("Task stderr [{}]: {}", output_identifier, line); } } buffer }) } +fn extract_json_from_line(line: &str) -> Option { + let start = line.find('{')?; + let end = line.rfind('}')?; + + if start >= end { + return None; + } + + let potential_json = &line[start..=end]; + if serde_json::from_str::(potential_json).is_ok() { + Some(potential_json.to_string()) + } else { + None + } +} + fn process_output(stdout_output: String) -> Result { let last_line = stdout_output .lines() @@ -186,14 +206,9 @@ fn process_output(stdout_output: String) -> Result { .next_back() .unwrap_or(""); - if let (Some(start), Some(end)) = (last_line.find('{'), last_line.rfind('}')) { - if start < end { - let potential_json = &last_line[start..=end]; - - if serde_json::from_str::(potential_json).is_ok() { - return Ok(Value::String(potential_json.to_string())); - } - } + if let Some(json_string) = extract_json_from_line(last_line) { + Ok(Value::String(json_string)) + } else { + Ok(Value::String(stdout_output)) } - Ok(Value::String(stdout_output)) } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 4184eb65ec36..e10c5adb889e 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -16,11 +16,9 @@ pub struct Task { impl Task { pub fn get_sub_recipe(&self) -> Option<&Map> { - if self.task_type == "sub_recipe" { - self.payload.get("sub_recipe").and_then(|sr| sr.as_object()) - } else { - None - } + (self.task_type == "sub_recipe") + .then(|| self.payload.get("sub_recipe")?.as_object()) + .flatten() } pub fn get_command_parameters(&self) -> Option<&Map> { @@ -124,12 +122,15 @@ impl Default for Config { } } +const DEFAULT_MAX_WORKERS: usize = 10; +const DEFAULT_INITIAL_WORKERS: usize = 2; + fn default_max_workers() -> usize { - 10 + DEFAULT_MAX_WORKERS } fn default_initial_workers() -> usize { - 2 + DEFAULT_INITIAL_WORKERS } #[derive(Debug, Serialize)] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index bf806a50d182..e14dbb2f20bb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -10,9 +10,6 @@ pub fn get_task_name(task_info: &TaskInfo) -> &str { .unwrap_or(&task_info.task.id) } -pub fn get_command_parameters(task_info: &TaskInfo) -> Option<&Map> { - task_info.task.get_command_parameters() -} pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index 0e0500c490ec..618026a7f003 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -3,6 +3,17 @@ use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_ use serde_json::json; use std::collections::HashMap; +fn create_task_info_with_defaults(task: Task, status: TaskStatus) -> TaskInfo { + TaskInfo { + task, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + } +} + mod test_get_task_name { use super::*; @@ -20,14 +31,7 @@ mod test_get_task_name { }), }; - let task_info = TaskInfo { - task: sub_recipe_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(sub_recipe_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "my_recipe"); } @@ -41,14 +45,7 @@ mod test_get_task_name { payload: json!({"text_instruction": "do something"}), }; - let task_info = TaskInfo { - task: text_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(text_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "task_2"); } @@ -67,14 +64,7 @@ mod test_get_task_name { }), }; - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "task_3"); } @@ -88,14 +78,7 @@ mod test_get_task_name { payload: json!({}), // missing "sub_recipe" field }; - let task_info = TaskInfo { - task: malformed_task, - status: TaskStatus::Pending, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - }; + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); assert_eq!(get_task_name(&task_info), "task_4"); } @@ -105,19 +88,13 @@ mod count_by_status { use super::*; fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { - TaskInfo { - task: Task { - id: id.to_string(), - task_type: "test".to_string(), - timeout_in_seconds: None, - payload: json!({}), - }, - status, - start_time: None, - end_time: None, - result: None, - current_output: String::new(), - } + let task = Task { + id: id.to_string(), + task_type: "test".to_string(), + timeout_in_seconds: None, + payload: json!({}), + }; + create_task_info_with_defaults(task, status) } #[test] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index f891595456b5..35e9f6d22219 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -21,7 +21,7 @@ async fn worker_loop(state: Arc, _worker_id: usize) { let result = process_task(&task, state.task_execution_tracker.clone()).await; if let Err(e) = state.result_sender.send(result).await { - eprintln!("Worker failed to send result: {}", e); + tracing::error!("Worker failed to send result: {}", e); break; } } From c14da819a16a105b70f4387a49f00c10ab4722d6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 16:20:32 +1000 Subject: [PATCH 19/34] refactored parsing recipe --- .../agents/recipe_tools/sub_recipe_tools.rs | 6 ++--- .../sub_recipe_execution_tool/utils/mod.rs | 1 - crates/goose/src/recipe/mod.rs | 24 ++++++------------- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index dac0825beda2..db4b3fed6c08 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -39,12 +39,12 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } -fn extract_task_parameters(params: &Value) -> &Vec { - let empty_vec = vec![]; +fn extract_task_parameters(params: &Value) -> Vec { params .get("task_parameters") .and_then(|v| v.as_array()) - .unwrap_or(&empty_vec) + .cloned() + .unwrap_or_default() } fn create_tasks_from_params( diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index e14dbb2f20bb..c2e39b15ade3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,4 +1,3 @@ -use serde_json::{Map, Value}; use std::collections::HashMap; use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 525433221ef2..144c7831b32c 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -281,23 +281,13 @@ impl Recipe { } } pub fn from_content(content: &str) -> Result { - if let Ok(json_value) = serde_json::from_str::(content) { - if let Some(nested_recipe) = json_value.get("recipe") { - Ok(serde_json::from_value(nested_recipe.clone())?) - } else { - Ok(serde_json::from_str(content)?) - } - } else if let Ok(yaml_value) = serde_yaml::from_str::(content) { - if let Some(nested_recipe) = yaml_value.get("recipe") { - Ok(serde_yaml::from_value(nested_recipe.clone())?) - } else { - Ok(serde_yaml::from_str(content)?) - } - } else { - Err(anyhow::anyhow!( - "Unsupported format. Expected JSON or YAML." - )) - } + let yaml_value = serde_yaml::from_str::(content) + .map_err(|_| anyhow::anyhow!("Unsupported format. Expected JSON or YAML."))?; + + // Check if there's a nested "recipe" key, otherwise use the root + let recipe_value = yaml_value.get("recipe").unwrap_or(&yaml_value); + + Ok(serde_yaml::from_value(recipe_value.clone())?) } } From a01d7127bf9c69013ba5f77b291223fe33f5d569 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 17:39:48 +1000 Subject: [PATCH 20/34] created taskExecutionNotificationEvent --- .../src/session/task_execution_display.rs | 319 ++++++-------- .../agents/recipe_tools/param_utils/tests.rs | 415 +++++++++--------- .../agents/recipe_tools/sub_recipe_tools.rs | 2 +- .../agents/sub_recipe_execution_tool/lib.rs | 18 +- .../agents/sub_recipe_execution_tool/mod.rs | 1 + .../notification_events.rs | 204 +++++++++ .../sub_recipe_execute_task_tool.rs | 7 +- .../task_execution_tracker.rs | 129 +++--- .../agents/sub_recipe_execution_tool/tasks.rs | 10 +- .../agents/sub_recipe_execution_tool/types.rs | 19 + .../sub_recipe_execution_tool/utils/mod.rs | 1 - crates/goose/src/recipe/mod.rs | 24 +- 12 files changed, 668 insertions(+), 481 deletions(-) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display.rs index 547c41b09e88..f8bad819f4d7 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display.rs @@ -1,3 +1,7 @@ +use goose::agents::sub_recipe_execution_tool::lib::TaskStatus; +use goose::agents::sub_recipe_execution_tool::notification_events::{ + TaskExecutionNotificationEvent, TaskInfo, +}; use serde_json::Value; use std::sync::atomic::{AtomicBool, Ordering}; @@ -9,130 +13,13 @@ pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false); -pub fn format_tasks_update(data: &Value) -> String { - let mut display = String::new(); - - if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { - display.push_str(CLEAR_SCREEN); - display.push_str("🎯 Task Execution Dashboard\n"); - display.push_str("═══════════════════════════\n\n"); - } else { - display.push_str(MOVE_TO_PROGRESS_LINE); - } - - if let Some(stats) = data.get("stats") { - let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); - let pending = stats.get("pending").and_then(|v| v.as_u64()).unwrap_or(0); - let running = stats.get("running").and_then(|v| v.as_u64()).unwrap_or(0); - let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); - let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); - - display.push_str(&format!( - "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", - total, pending, running, completed, failed - )); - display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); - } - - if let Some(tasks) = data.get("tasks").and_then(|t| t.as_array()) { - let mut sorted_tasks: Vec<_> = tasks.iter().collect(); - sorted_tasks.sort_by_key(|task| task.get("id").and_then(|v| v.as_str()).unwrap_or("")); - - for task in sorted_tasks { - display.push_str(&format_task_from_json(task)); - } - } - - display.push_str(CLEAR_BELOW); - display -} - -fn format_task_from_json(task: &Value) -> String { - let mut task_display = String::new(); - - let id = task.get("id").and_then(|v| v.as_str()).unwrap_or("unknown"); - let status = task - .get("status") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let task_type = task - .get("task_type") - .and_then(|v| v.as_str()) - .unwrap_or("task"); - let task_name = task.get("task_name").and_then(|v| v.as_str()).unwrap_or(id); - let task_metadata = task - .get("task_metadata") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - let status_icon = match status { - "Pending" => "⏳", - "Running" => "🏃", - "Completed" => "✅", - "Failed" => "❌", - _ => "◯", - }; - - task_display.push_str(&format!( - "{} {} ({}){}\n", - status_icon, task_name, task_type, CLEAR_TO_EOL - )); - - if !task_metadata.is_empty() { - task_display.push_str(&format!( - " 📋 Parameters: {}{}\n", - task_metadata, CLEAR_TO_EOL - )); - } - - if let Some(duration_secs) = task.get("duration_secs").and_then(|v| v.as_f64()) { - task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); - } - - if status == "Running" { - if let Some(current_output) = task.get("current_output").and_then(|v| v.as_str()) { - if !current_output.trim().is_empty() { - let processed_output = process_output_for_display(current_output); - if !processed_output.is_empty() { - task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); - } - } - } - } - - if status == "Completed" { - if let Some(result_data) = task.get("result_data") { - let result_preview = format_result_data_for_display(result_data); - if !result_preview.is_empty() { - task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); - } - } - } - - if status == "Failed" { - if let Some(error) = task.get("error").and_then(|v| v.as_str()) { - let error_preview = truncate_with_ellipsis(error, 80); - task_display.push_str(&format!( - " ⚠️ {}{}\n", - error_preview.replace('\n', " "), - CLEAR_TO_EOL - )); - } - } - - task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); - task_display -} - fn format_result_data_for_display(result_data: &Value) -> String { match result_data { Value::String(s) => strip_ansi_codes(s), Value::Object(obj) => { - // Handle specific result formats if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) { format!("Partial output: {}", partial_output) } else { - // Generic object display serde_json::to_string_pretty(obj).unwrap_or_default() } } @@ -190,84 +77,154 @@ fn strip_ansi_codes(text: &str) -> String { result } -pub fn format_tasks_complete(data: &Value) -> String { - let mut summary = String::new(); - summary.push_str("Execution Complete!\n"); - summary.push_str("═══════════════════════\n"); +pub fn format_task_execution_notification( + data: &Value, +) -> Option<(String, Option, Option)> { + if let Ok(event) = serde_json::from_value::(data.clone()) { + return Some(match event { + TaskExecutionNotificationEvent::LineOutput { output, .. } => ( + format!("{}\n", output), + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ), + TaskExecutionNotificationEvent::TasksUpdate { .. } => { + let formatted_display = format_tasks_update_from_event(&event); + ( + formatted_display, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + TaskExecutionNotificationEvent::TasksComplete { .. } => { + let formatted_summary = format_tasks_complete_from_event(&event); + ( + formatted_summary, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + }); + } + None +} + +fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksUpdate { stats, tasks } = event { + let mut display = String::new(); + + if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { + display.push_str(CLEAR_SCREEN); + display.push_str("🎯 Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); + } + + display.push_str(&format!( + "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed", + stats.total, stats.pending, stats.running, stats.completed, stats.failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); - if let Some(stats) = data.get("stats") { - let total = stats.get("total").and_then(|v| v.as_u64()).unwrap_or(0); - let completed = stats.get("completed").and_then(|v| v.as_u64()).unwrap_or(0); - let failed = stats.get("failed").and_then(|v| v.as_u64()).unwrap_or(0); - let success_rate = stats - .get("success_rate") - .and_then(|v| v.as_f64()) - .unwrap_or(0.0); + let mut sorted_tasks = tasks.clone(); + sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id)); - summary.push_str(&format!("Total Tasks: {}\n", total)); - summary.push_str(&format!("✅ Completed: {}\n", completed)); - summary.push_str(&format!("❌ Failed: {}\n", failed)); - summary.push_str(&format!("📈 Success Rate: {:.1}%\n", success_rate)); + for task in sorted_tasks { + display.push_str(&format_task_from_struct(&task)); + } + + display.push_str(CLEAR_BELOW); + display + } else { + String::new() } +} + +fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + } = event + { + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); + + summary.push_str(&format!("Total Tasks: {}\n", stats.total)); + summary.push_str(&format!("✅ Completed: {}\n", stats.completed)); + summary.push_str(&format!("❌ Failed: {}\n", stats.failed)); + summary.push_str(&format!("📈 Success Rate: {:.1}%\n", stats.success_rate)); - if let Some(failed_tasks) = data.get("failed_tasks").and_then(|t| t.as_array()) { if !failed_tasks.is_empty() { summary.push_str("\n❌ Failed Tasks:\n"); for task in failed_tasks { - let name = task - .get("name") - .and_then(|v| v.as_str()) - .unwrap_or("Unknown"); - summary.push_str(&format!(" • {}\n", name)); - if let Some(error) = task.get("error").and_then(|v| v.as_str()) { + summary.push_str(&format!(" • {}\n", task.name)); + if let Some(error) = &task.error { summary.push_str(&format!(" Error: {}\n", error)); } } } - } - summary.push_str("\n📝 Generating summary...\n"); - summary + summary.push_str("\n📝 Generating summary...\n"); + summary + } else { + String::new() + } } -pub fn format_task_execution_notification( - data: &Value, -) -> Option<(String, Option, Option)> { - if let Value::Object(o) = data { - if o.get("type").and_then(|t| t.as_str()) == Some(TASK_EXECUTION_NOTIFICATION_TYPE) { - return Some(match o.get("subtype").and_then(|t| t.as_str()) { - Some("line_output") => { - if let Some(Value::String(line_output)) = o.get("output") { - ( - format!("{}\n", line_output), - None, - Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), - ) - } else { - (data.to_string(), None, None) - } - } - Some("tasks_update") => { - let data_value = Value::Object(o.clone()); - let formatted_display = format_tasks_update(&data_value); - ( - formatted_display, - None, - Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), - ) - } - Some("tasks_complete") => { - let data_value = Value::Object(o.clone()); - let formatted_summary = format_tasks_complete(&data_value); - ( - formatted_summary, - None, - Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), - ) - } - _ => (data.to_string(), None, None), - }); +fn format_task_from_struct(task: &TaskInfo) -> String { + let mut task_display = String::new(); + + let status_icon = match task.status { + TaskStatus::Pending => "⏳", + TaskStatus::Running => "🏃", + TaskStatus::Completed => "✅", + TaskStatus::Failed => "❌", + }; + + task_display.push_str(&format!( + "{} {} ({}){}\n", + status_icon, task.task_name, task.task_type, CLEAR_TO_EOL + )); + + if !task.task_metadata.is_empty() { + task_display.push_str(&format!( + " 📋 Parameters: {}{}\n", + task.task_metadata, CLEAR_TO_EOL + )); + } + + if let Some(duration_secs) = task.duration_secs { + task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); + } + + if matches!(task.status, TaskStatus::Running) && !task.current_output.trim().is_empty() { + let processed_output = process_output_for_display(&task.current_output); + if !processed_output.is_empty() { + task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL)); } } - None + + if matches!(task.status, TaskStatus::Completed) { + if let Some(result_data) = &task.result_data { + let result_preview = format_result_data_for_display(result_data); + if !result_preview.is_empty() { + task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL)); + } + } + } + + if matches!(task.status, TaskStatus::Failed) { + if let Some(error) = &task.error { + let error_preview = truncate_with_ellipsis(error, 80); + task_display.push_str(&format!( + " ⚠️ {}{}\n", + error_preview.replace('\n', " "), + CLEAR_TO_EOL + )); + } + } + + task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); + task_display } diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index 0a52da64f428..9cbec0c7b828 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -5,229 +5,224 @@ use serde_json::json; use crate::agents::recipe_tools::param_utils::prepare_command_params; - fn setup_default_sub_recipe() -> SubRecipe { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - timeout_in_seconds: None, - values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), - executions: None, - }; - sub_recipe +fn setup_default_sub_recipe() -> SubRecipe { + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: "test_sub_recipe.yaml".to_string(), + timeout_in_seconds: None, + values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + executions: None, + }; + sub_recipe +} + +fn create_execution_values(key: &str, values: Vec) -> Execution { + let runs = values + .iter() + .map(|value| ExecutionRun { + values: Some(HashMap::from([(key.to_string(), value.to_string())])), + }) + .collect(); + Execution { + parallel: true, + runs: Some(runs), } +} - fn create_execution_values(key: &str, values: Vec) -> Execution { - let runs = values - .iter() - .map(|value| ExecutionRun { - values: Some(HashMap::from([(key.to_string(), value.to_string())])), - }) - .collect(); - Execution { - parallel: true, - runs: Some(runs), +mod prepare_command_params_tests { + use super::*; + + mod without_execution_runs { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result.len(), 0); } } - mod prepare_command_params_tests { + mod with_execution_runs { use super::*; - mod without_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "value2".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_value_override_passed_param_value() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "different_value".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = + Some(create_execution_values("key2", vec!["value2".to_string()])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ ("key1".to_string(), "value1".to_string()), ("key2".to_string(), "value2".to_string()), - ])); + ("key3".to_string(), "value3".to_string()) + ]),], + result + ); + } - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + sub_recipe.executions = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_return_command_param_when_all_from_values_in_sub_recipe() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_empty_command_param() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!(result.len(), 0); - } + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); } - mod with_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = - Some(create_execution_values("key2", vec!["value2".to_string()])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ + #[test] + fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( + ) { + let parameter_array = vec![json!(HashMap::from([( + "key3".to_string(), + "value3".to_string() + ),]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string(), "key2_value2".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ("key3".to_string(), "value3".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_all_values_from_tool_call_parameters() { - let parameter_array = vec![ - json!(HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()) - ])), - json!(HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()) - ])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; - sub_recipe.executions = None; - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "key1_value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "key1_value2".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ]), - ], - result - ); - } - - #[test] - fn test_return_command_param_when_all_from_values_in_sub_recipe() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( - ) { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - ),]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters( - ) { - let parameter_array = vec![ - json!(HashMap::from([("key3".to_string(), "value3".to_string())])), - json!(HashMap::from([("key4".to_string(), "value4".to_string())])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array); - - assert!(result.is_err()); - } + ("key2".to_string(), "key2_value1".to_string()), + ("key3".to_string(), "value3".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ("key3".to_string(), "value3".to_string()), + ]) + ], + result + ); + } + + #[test] + fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([("key3".to_string(), "value3".to_string())])), + json!(HashMap::from([("key4".to_string(), "value4".to_string())])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + sub_recipe.executions = Some(create_execution_values( + "key2", + vec!["key2_value1".to_string()], + )); + + let result = prepare_command_params(&sub_recipe, parameter_array); + + assert!(result.is_err()); } } +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index db4b3fed6c08..fd0edaeaecd1 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -77,7 +77,7 @@ fn get_execution_mode(sub_recipe: &SubRecipe) -> &'static str { .as_ref() .map(|e| e.parallel) .unwrap_or(false); - + if is_parallel { EXECUTION_MODE_PARALLEL } else { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs index e1e20d772969..4cded2c71000 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -2,7 +2,8 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, + Config, ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, + TaskStatus, }; use mcp_core::protocol::JsonRpcMessage; @@ -11,7 +12,7 @@ use tokio::sync::mpsc; pub async fn execute_tasks( input: Value, - execution_mode: &str, + execution_mode: ExecutionMode, notifier: mpsc::Sender, ) -> Result { let tasks: Vec = @@ -27,7 +28,7 @@ pub async fn execute_tasks( let task_count = tasks.len(); match execution_mode { - "sequential" => { + ExecutionMode::Sequential => { if task_count == 1 { let response = execute_single_task(&tasks[0], notifier).await; handle_response(response) @@ -35,12 +36,11 @@ pub async fn execute_tasks( Err("Sequential execution mode requires exactly one task".to_string()) } } - "parallel" => { + ExecutionMode::Parallel => { let response: ExecutionResponse = execute_tasks_in_parallel(tasks, config, notifier).await; handle_response(response) } - _ => Err("Invalid execution mode".to_string()), } } @@ -48,7 +48,7 @@ fn extract_failed_tasks(results: &[TaskResult]) -> Vec { results .iter() .filter(|r| matches!(r.status, TaskStatus::Failed)) - .map(|r| format_failed_task_error(r)) + .map(format_failed_task_error) .collect() } @@ -71,7 +71,11 @@ fn format_failed_task_error(result: &TaskResult) -> String { ) } -fn format_error_summary(failed_count: usize, total_count: usize, failed_tasks: Vec) -> String { +fn format_error_summary( + failed_count: usize, + total_count: usize, + failed_tasks: Vec, +) -> String { format!( "{}/{} tasks failed:\n{}", failed_count, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index 03568dc53197..267caab8f115 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,5 +1,6 @@ mod executor; pub mod lib; +pub mod notification_events; pub mod sub_recipe_execute_task_tool; mod task_execution_tracker; mod tasks; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs new file mode 100644 index 000000000000..97a4576b2d99 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs @@ -0,0 +1,204 @@ +use crate::agents::sub_recipe_execution_tool::types::TaskStatus; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype")] +pub enum TaskExecutionNotificationEvent { + #[serde(rename = "line_output")] + LineOutput { task_id: String, output: String }, + #[serde(rename = "tasks_update")] + TasksUpdate { + stats: TaskExecutionStats, + tasks: Vec, + }, + #[serde(rename = "tasks_complete")] + TasksComplete { + stats: TaskCompletionStats, + failed_tasks: Vec, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskExecutionStats { + pub total: usize, + pub pending: usize, + pub running: usize, + pub completed: usize, + pub failed: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskCompletionStats { + pub total: usize, + pub completed: usize, + pub failed: usize, + pub success_rate: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskInfo { + pub id: String, + pub status: TaskStatus, + pub duration_secs: Option, + pub current_output: String, + pub task_type: String, + pub task_name: String, + pub task_metadata: String, + pub error: Option, + pub result_data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FailedTaskInfo { + pub id: String, + pub name: String, + pub error: Option, +} + +impl TaskExecutionNotificationEvent { + pub fn line_output(task_id: String, output: String) -> Self { + Self::LineOutput { task_id, output } + } + + pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec) -> Self { + Self::TasksUpdate { stats, tasks } + } + + pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec) -> Self { + Self::TasksComplete { + stats, + failed_tasks, + } + } + + /// Convert event to JSON format for MCP notification + pub fn to_notification_data(&self) -> serde_json::Value { + let mut event_data = serde_json::to_value(self).expect("Failed to serialize event"); + + // Add the type field at the root level + if let serde_json::Value::Object(ref mut map) = event_data { + map.insert( + "type".to_string(), + serde_json::Value::String("task_execution".to_string()), + ); + } + + event_data + } +} + +impl TaskExecutionStats { + pub fn new( + total: usize, + pending: usize, + running: usize, + completed: usize, + failed: usize, + ) -> Self { + Self { + total, + pending, + running, + completed, + failed, + } + } +} + +impl TaskCompletionStats { + pub fn new(total: usize, completed: usize, failed: usize) -> Self { + let success_rate = if total > 0 { + (completed as f64 / total as f64) * 100.0 + } else { + 0.0 + }; + + Self { + total, + completed, + failed, + success_rate, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_line_output_event_serialization() { + let event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Hello World".to_string(), + ); + + let notification_data = event.to_notification_data(); + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "line_output"); + assert_eq!(notification_data["task_id"], "task-1"); + assert_eq!(notification_data["output"], "Hello World"); + } + + #[test] + fn test_tasks_update_event_serialization() { + let stats = TaskExecutionStats::new(5, 2, 1, 1, 1); + let tasks = vec![TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }]; + + let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks); + let notification_data = event.to_notification_data(); + + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "tasks_update"); + assert_eq!(notification_data["stats"]["total"], 5); + assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1); + } + + #[test] + fn test_event_roundtrip_serialization() { + let original_event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Test output".to_string(), + ); + + // Serialize to JSON + let json_data = original_event.to_notification_data(); + + // Deserialize back to event (excluding the type field) + let mut event_data = json_data.clone(); + if let serde_json::Value::Object(ref mut map) = event_data { + map.remove("type"); + } + + let deserialized_event: TaskExecutionNotificationEvent = + serde_json::from_value(event_data).expect("Failed to deserialize"); + + match (original_event, deserialized_event) { + ( + TaskExecutionNotificationEvent::LineOutput { + task_id: id1, + output: out1, + }, + TaskExecutionNotificationEvent::LineOutput { + task_id: id2, + output: out2, + }, + ) => { + assert_eq!(id1, id2); + assert_eq!(out1, out2); + } + _ => panic!("Event types don't match after roundtrip"), + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index f1ca709cb318..80a537393e5b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -2,7 +2,8 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, + sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::types::ExecutionMode, + tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; @@ -128,8 +129,8 @@ pub async fn run_tasks(execute_data: Value) -> ToolCallResult { let execute_data_clone = execute_data.clone(); let execution_mode = execute_data_clone .get("execution_mode") - .and_then(|v| v.as_str()) - .unwrap_or("sequential"); + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or_default(); match execute_tasks(execute_data, execution_mode, notification_tx).await { Ok(result) => { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 96176fc77930..639e4580b843 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -5,6 +5,10 @@ use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; +use crate::agents::sub_recipe_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, + TaskInfo as EventTaskInfo, +}; use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::Value; @@ -110,21 +114,21 @@ impl TaskExecutionTracker { pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { - // Send raw output data - let subscriber format it - if let Err(e) = self - .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "task_execution", - "subtype": "line_output", - "task_id": task_id, - "output": line - } - })), - })) { + let event = TaskExecutionNotificationEvent::line_output( + task_id.to_string(), + line.to_string(), + ); + + if let Err(e) = + self.notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { tracing::warn!("Failed to send live output notification: {}", e); } } @@ -160,45 +164,44 @@ impl TaskExecutionTracker { let task_list: Vec<_> = tasks.values().collect(); let (total, pending, running, completed, failed) = count_by_status(&tasks); + let stats = TaskExecutionStats::new(total, pending, running, completed, failed); + + let event_tasks: Vec = task_list + .iter() + .map(|task_info| { + let now = Instant::now(); + EventTaskInfo { + id: task_info.task.id.clone(), + status: task_info.status.clone(), + duration_secs: task_info.start_time.map(|start| { + if let Some(end) = task_info.end_time { + end.duration_since(start).as_secs_f64() + } else { + now.duration_since(start).as_secs_f64() + } + }), + current_output: task_info.current_output.clone(), + task_type: task_info.task.task_type.clone(), + task_name: get_task_name(task_info).to_string(), + task_metadata: format_task_metadata(task_info), + error: task_info.error().cloned(), + result_data: task_info.data().cloned(), + } + }) + .collect(); + + let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks); + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: "notifications/message".to_string(), params: Some(json!({ - "data": { - "type": "task_execution", - "subtype": "tasks_update", - "stats": { - "total": total, - "pending": pending, - "running": running, - "completed": completed, - "failed": failed - }, - "tasks": task_list.iter().map(|task_info| { - let now = Instant::now(); - json!({ - "id": task_info.task.id, - "status": task_info.status, - "duration_secs": task_info.start_time.map(|start| { - if let Some(end) = task_info.end_time { - end.duration_since(start).as_secs_f64() - } else { - now.duration_since(start).as_secs_f64() - } - }), - "current_output": task_info.current_output, - "task_type": task_info.task.task_type, - "task_name": get_task_name(task_info), - "task_metadata": format_task_metadata(task_info), - "error": task_info.error(), - "result_data": task_info.data() - }) - }).collect::>() - } + "data": event.to_notification_data() })), - })) { + })) + { tracing::warn!("Failed to send tasks update notification: {}", e); } } @@ -236,38 +239,30 @@ impl TaskExecutionTracker { let tasks = self.tasks.read().await; let (total, _, _, completed, failed) = count_by_status(&tasks); - // Send structured summary data only - let failed_tasks: Vec<_> = tasks + let stats = TaskCompletionStats::new(total, completed, failed); + + let failed_tasks: Vec = tasks .values() .filter(|task_info| matches!(task_info.status, TaskStatus::Failed)) - .map(|task_info| { - json!({ - "id": task_info.task.id, - "name": get_task_name(task_info), - "error": task_info.error() - }) + .map(|task_info| FailedTaskInfo { + id: task_info.task.id.clone(), + name: get_task_name(task_info).to_string(), + error: task_info.error().cloned(), }) .collect(); + let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks); + if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: "notifications/message".to_string(), params: Some(json!({ - "data": { - "type": "task_execution", - "subtype": "tasks_complete", - "stats": { - "total": total, - "completed": completed, - "failed": failed, - "success_rate": (completed as f64 / total as f64) * 100.0 - }, - "failed_tasks": failed_tasks - } + "data": event.to_notification_data() })), - })) { + })) + { tracing::warn!("Failed to send tasks complete notification: {}", e); } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 43443804ddd9..60d022816680 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -15,7 +15,9 @@ pub async fn process_task( task: &Task, task_execution_tracker: Arc, ) -> TaskResult { - let timeout_in_seconds = task.timeout_in_seconds.unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); + let timeout_in_seconds = task + .timeout_in_seconds + .unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); let task_clone = task.clone(); let timeout_duration = Duration::from_secs(timeout_in_seconds); @@ -78,7 +80,7 @@ async fn get_task_result( fn build_command(task: &Task) -> Result<(Command, String), String> { let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); - + let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { let sub_recipe_name = task @@ -186,11 +188,11 @@ fn spawn_output_reader( fn extract_json_from_line(line: &str) -> Option { let start = line.find('{')?; let end = line.rfind('}')?; - + if start >= end { return None; } - + let potential_json = &line[start..=end]; if serde_json::from_str::(potential_json).is_ok() { Some(potential_json.to_string()) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index e10c5adb889e..e558cee1f08b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -6,6 +6,14 @@ use tokio::sync::mpsc; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "lowercase")] +pub enum ExecutionMode { + #[default] + Sequential, + Parallel, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Task { pub id: String, @@ -68,6 +76,17 @@ pub enum TaskStatus { Failed, } +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "Pending"), + TaskStatus::Running => write!(f, "Running"), + TaskStatus::Completed => write!(f, "Completed"), + TaskStatus::Failed => write!(f, "Failed"), + } + } +} + #[derive(Debug, Clone)] pub struct TaskInfo { pub task: Task, diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index c2e39b15ade3..b86a69a8fcfe 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -9,7 +9,6 @@ pub fn get_task_name(task_info: &TaskInfo) -> &str { .unwrap_or(&task_info.task.id) } - pub fn count_by_status(tasks: &HashMap) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); let (pending, running, completed, failed) = tasks.values().fold( diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 144c7831b32c..525433221ef2 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -281,13 +281,23 @@ impl Recipe { } } pub fn from_content(content: &str) -> Result { - let yaml_value = serde_yaml::from_str::(content) - .map_err(|_| anyhow::anyhow!("Unsupported format. Expected JSON or YAML."))?; - - // Check if there's a nested "recipe" key, otherwise use the root - let recipe_value = yaml_value.get("recipe").unwrap_or(&yaml_value); - - Ok(serde_yaml::from_value(recipe_value.clone())?) + if let Ok(json_value) = serde_json::from_str::(content) { + if let Some(nested_recipe) = json_value.get("recipe") { + Ok(serde_json::from_value(nested_recipe.clone())?) + } else { + Ok(serde_json::from_str(content)?) + } + } else if let Ok(yaml_value) = serde_yaml::from_str::(content) { + if let Some(nested_recipe) = yaml_value.get("recipe") { + Ok(serde_yaml::from_value(nested_recipe.clone())?) + } else { + Ok(serde_yaml::from_str(content)?) + } + } else { + Err(anyhow::anyhow!( + "Unsupported format. Expected JSON or YAML." + )) + } } } From 41c46a65073235b8153dede3d890218a63049d4a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 18:43:33 +1000 Subject: [PATCH 21/34] added tests --- .../mod.rs} | 29 +- .../session/task_execution_display/tests.rs | 337 ++++++++++++++++++ .../{executor.rs => executor/mod.rs} | 3 + .../executor/tests.rs | 100 ++++++ .../{lib.rs => lib/mod.rs} | 3 + .../sub_recipe_execution_tool/lib/tests.rs | 216 +++++++++++ 6 files changed, 680 insertions(+), 8 deletions(-) rename crates/goose-cli/src/session/{task_execution_display.rs => task_execution_display/mod.rs} (89%) create mode 100644 crates/goose-cli/src/session/task_execution_display/tests.rs rename crates/goose/src/agents/sub_recipe_execution_tool/{executor.rs => executor/mod.rs} (99%) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs rename crates/goose/src/agents/sub_recipe_execution_tool/{lib.rs => lib/mod.rs} (99%) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs diff --git a/crates/goose-cli/src/session/task_execution_display.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs similarity index 89% rename from crates/goose-cli/src/session/task_execution_display.rs rename to crates/goose-cli/src/session/task_execution_display/mod.rs index f8bad819f4d7..96d37b76d483 100644 --- a/crates/goose-cli/src/session/task_execution_display.rs +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -5,6 +5,9 @@ use goose::agents::sub_recipe_execution_tool::notification_events::{ use serde_json::Value; use std::sync::atomic::{AtomicBool, Ordering}; +#[cfg(test)] +mod tests; + const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; const CLEAR_TO_EOL: &str = "\x1b[K"; @@ -60,14 +63,24 @@ fn strip_ansi_codes(text: &str) -> String { while let Some(ch) = chars.next() { if ch == '\x1b' { - if chars.next() == Some('[') { - loop { - match chars.next() { - Some(c) if c.is_ascii_alphabetic() => break, - Some(_) => continue, - None => break, + if let Some(next_ch) = chars.next() { + if next_ch == '[' { + // This is an ANSI escape sequence, consume until alphabetic character + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } } + } else { + // Not an ANSI sequence, keep both characters + result.push(ch); + result.push(next_ch); } + } else { + // End of string after \x1b + result.push(ch); } } else { result.push(ch); @@ -130,7 +143,7 @@ fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> Str sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id)); for task in sorted_tasks { - display.push_str(&format_task_from_struct(&task)); + display.push_str(&format_task_display(&task)); } display.push_str(CLEAR_BELOW); @@ -172,7 +185,7 @@ fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> S } } -fn format_task_from_struct(task: &TaskInfo) -> String { +fn format_task_display(task: &TaskInfo) -> String { let mut task_display = String::new(); let status_icon = match task.status { diff --git a/crates/goose-cli/src/session/task_execution_display/tests.rs b/crates/goose-cli/src/session/task_execution_display/tests.rs new file mode 100644 index 000000000000..fb53285080d3 --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display/tests.rs @@ -0,0 +1,337 @@ +use super::*; +use goose::agents::sub_recipe_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionStats, +}; +use serde_json::json; + +#[test] +fn test_strip_ansi_codes() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("normal\x1b[33myellow\x1b[0mnormal"), + "normalyellownormal" + ); + assert_eq!(strip_ansi_codes("\x1bhello"), "\x1bhello"); + assert_eq!(strip_ansi_codes("hello\x1b"), "hello\x1b"); + assert_eq!(strip_ansi_codes(""), ""); +} + +#[test] +fn test_truncate_with_ellipsis() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + assert_eq!(truncate_with_ellipsis("hello world", 8), "hello..."); + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hello", 2), "..."); + assert_eq!(truncate_with_ellipsis("hello", 1), "..."); + assert_eq!(truncate_with_ellipsis("", 5), ""); +} + +#[test] +fn test_process_output_for_display() { + assert_eq!(process_output_for_display("hello world"), "hello world"); + assert_eq!( + process_output_for_display("line1\nline2"), + "line1 ... line2" + ); + + let input = "line1\nline2\nline3\nline4"; + let result = process_output_for_display(input); + assert_eq!(result, "line3 ... line4"); + + let long_line = "a".repeat(150); + let result = process_output_for_display(&long_line); + assert!(result.len() <= 100); + assert!(result.ends_with("...")); + + let ansi_output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_for_display(ansi_output); + assert_eq!(result, "red line 1 ... green line 2"); + + assert_eq!(process_output_for_display(""), ""); +} + +#[test] +fn test_format_result_data_for_display() { + let string_val = json!("hello world"); + assert_eq!(format_result_data_for_display(&string_val), "hello world"); + + let ansi_string = json!("\x1b[31mred text\x1b[0m"); + assert_eq!(format_result_data_for_display(&ansi_string), "red text"); + + assert_eq!(format_result_data_for_display(&json!(true)), "true"); + assert_eq!(format_result_data_for_display(&json!(false)), "false"); + assert_eq!(format_result_data_for_display(&json!(42)), "42"); + assert_eq!(format_result_data_for_display(&json!(3.14)), "3.14"); + assert_eq!(format_result_data_for_display(&json!(null)), "null"); + + let partial_obj = json!({ + "partial_output": "some output", + "other_field": "ignored" + }); + assert_eq!( + format_result_data_for_display(&partial_obj), + "Partial output: some output" + ); + + let obj = json!({"key": "value", "num": 42}); + let result = format_result_data_for_display(&obj); + assert!(result.contains("key")); + assert!(result.contains("value")); + + let arr = json!([1, 2, 3]); + let result = format_result_data_for_display(&arr); + assert!(result.contains("1")); + assert!(result.contains("2")); + assert!(result.contains("3")); +} + +#[test] +fn test_format_task_execution_notification_line_output() { + let _event = TaskExecutionNotificationEvent::LineOutput { + task_id: "task-1".to_string(), + output: "Hello World".to_string(), + }; + + let data = json!({ + "subtype": "line_output", + "task_id": "task-1", + "output": "Hello World" + }); + + let result = format_task_execution_notification(&data); + assert!(result.is_some()); + + let (formatted, second, third) = result.unwrap(); + assert_eq!(formatted, "Hello World\n"); + assert_eq!(second, None); + assert_eq!(third, Some("task_execution".to_string())); +} + +#[test] +fn test_format_task_execution_notification_invalid_data() { + let invalid_data = json!({ + "invalid": "structure" + }); + + let result = format_task_execution_notification(&invalid_data); + assert_eq!(result, None); + + let incomplete_data = json!({ + "subtype": "line_output" + }); + + let result = format_task_execution_notification(&incomplete_data); + assert_eq!(result, None); +} + +#[test] +fn test_format_tasks_update_from_event() { + INITIAL_SHOWN.store(false, Ordering::SeqCst); + + let stats = TaskExecutionStats::new(3, 1, 1, 1, 0); + let tasks = vec![ + TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }, + TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(2.3), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "another-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"result": "success"})), + }, + ]; + + let event = TaskExecutionNotificationEvent::TasksUpdate { stats, tasks }; + let result = format_tasks_update_from_event(&event); + + assert!(result.contains("🎯 Task Execution Dashboard")); + assert!(result.contains("═══════════════════════════")); + assert!(result.contains("📊 Progress: 3 total")); + assert!(result.contains("⏳ 1 pending")); + assert!(result.contains("🏃 1 running")); + assert!(result.contains("✅ 1 completed")); + assert!(result.contains("❌ 0 failed")); + assert!(result.contains("🏃 test-task")); + assert!(result.contains("✅ another-task")); + assert!(result.contains("📋 Parameters: param=value")); + assert!(result.contains("⏱️ 1.5s")); + assert!(result.contains("💬 Processing...")); + + let result2 = format_tasks_update_from_event(&event); + assert!(!result2.contains("🎯 Task Execution Dashboard")); + assert!(result2.contains(MOVE_TO_PROGRESS_LINE)); +} + +#[test] +fn test_format_tasks_complete_from_event() { + let stats = TaskCompletionStats::new(5, 4, 1); + let failed_tasks = vec![FailedTaskInfo { + id: "task-3".to_string(), + name: "failed-task".to_string(), + error: Some("Connection timeout".to_string()), + }]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(result.contains("Execution Complete!")); + assert!(result.contains("═══════════════════════")); + assert!(result.contains("Total Tasks: 5")); + assert!(result.contains("✅ Completed: 4")); + assert!(result.contains("❌ Failed: 1")); + assert!(result.contains("📈 Success Rate: 80.0%")); + assert!(result.contains("❌ Failed Tasks:")); + assert!(result.contains("• failed-task")); + assert!(result.contains("Error: Connection timeout")); + assert!(result.contains("📝 Generating summary...")); +} + +#[test] +fn test_format_tasks_complete_from_event_no_failures() { + let stats = TaskCompletionStats::new(3, 3, 0); + let failed_tasks = vec![]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(!result.contains("❌ Failed Tasks:")); + assert!(result.contains("📈 Success Rate: 100.0%")); + assert!(result.contains("❌ Failed: 0")); +} + +#[test] +fn test_format_task_display_running() { + let task = TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing data...\nAlmost done...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "data-processor".to_string(), + task_metadata: "input=file.txt,output=result.json".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("🏃 data-processor (sub_recipe)")); + assert!(result.contains("📋 Parameters: input=file.txt,output=result.json")); + assert!(result.contains("⏱️ 1.5s")); + assert!(result.contains("💬 Processing data... ... Almost done...")); +} + +#[test] +fn test_format_task_display_completed() { + let task = TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(3.2), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "analyzer".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"status": "success", "count": 42})), + }; + + let result = format_task_display(&task); + + assert!(result.contains("✅ analyzer (text_instruction)")); + assert!(result.contains("⏱️ 3.2s")); + assert!(!result.contains("📋 Parameters")); + assert!(result.contains("📄")); +} + +#[test] +fn test_format_task_display_failed() { + let task = TaskInfo { + id: "task-3".to_string(), + status: TaskStatus::Failed, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "failing-task".to_string(), + task_metadata: "".to_string(), + error: Some( + "Network connection failed after multiple retries. The server is unreachable." + .to_string(), + ), + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("❌ failing-task (sub_recipe)")); + assert!(!result.contains("⏱️")); + assert!(result.contains("⚠️")); + assert!(result.contains("Network connection failed after multiple retries")); +} + +#[test] +fn test_format_task_display_pending() { + let task = TaskInfo { + id: "task-4".to_string(), + status: TaskStatus::Pending, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "waiting-task".to_string(), + task_metadata: "priority=high".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("⏳ waiting-task (sub_recipe)")); + assert!(result.contains("📋 Parameters: priority=high")); + assert!(!result.contains("⏱️")); + assert!(!result.contains("💬")); + assert!(!result.contains("📄")); + assert!(!result.contains("⚠️")); +} + +#[test] +fn test_format_task_display_empty_current_output() { + let task = TaskInfo { + id: "task-5".to_string(), + status: TaskStatus::Running, + duration_secs: Some(0.5), + current_output: " \n\t \n ".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "quiet-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(!result.contains("💬")); +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs similarity index 99% rename from crates/goose/src/agents/sub_recipe_execution_tool/executor.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index 674bd13ae5e9..c6183b7c137b 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -13,6 +13,9 @@ use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ use crate::agents::sub_recipe_execution_tool::tasks::process_task; use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; +#[cfg(test)] +mod tests; + const EXECUTION_STATUS_COMPLETED: &str = "completed"; pub async fn execute_single_task( diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs new file mode 100644 index 000000000000..76385b87ef37 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs @@ -0,0 +1,100 @@ +use super::{calculate_stats, create_empty_response, create_error_response}; +use crate::agents::sub_recipe_execution_tool::lib::{TaskResult, TaskStatus}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus) -> TaskResult { + let is_failed = matches!(status, TaskStatus::Failed); + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"output": "test output"})), + error: if is_failed { + Some("Test error".to_string()) + } else { + None + }, + } +} + +#[test] +fn test_calculate_stats() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + create_test_task_result("task3", TaskStatus::Failed), + create_test_task_result("task4", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 1500); + + assert_eq!(stats.total_tasks, 4); + assert_eq!(stats.completed, 3); + assert_eq!(stats.failed, 1); + assert_eq!(stats.execution_time_ms, 1500); +} + +#[test] +fn test_calculate_stats_empty_results() { + let results = vec![]; + let stats = calculate_stats(&results, 0); + + assert_eq!(stats.total_tasks, 0); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 0); +} + +#[test] +fn test_calculate_stats_all_completed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 800); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 2); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 800); +} + +#[test] +fn test_calculate_stats_all_failed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed), + create_test_task_result("task2", TaskStatus::Failed), + ]; + + let stats = calculate_stats(&results, 1200); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 2); + assert_eq!(stats.execution_time_ms, 1200); +} + +#[test] +fn test_create_empty_response() { + let response = create_empty_response(); + + assert_eq!(response.status, "completed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 0); + assert_eq!(response.stats.execution_time_ms, 0); +} + +#[test] +fn test_create_error_response() { + let error_msg = "Test error message"; + let response = create_error_response(error_msg.to_string()); + + assert_eq!(response.status, "failed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 1); + assert_eq!(response.stats.execution_time_ms, 0); +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs similarity index 99% rename from crates/goose/src/agents/sub_recipe_execution_tool/lib.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 4cded2c71000..0515bf37bb44 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -6,6 +6,9 @@ pub use crate::agents::sub_recipe_execution_tool::types::{ TaskStatus, }; +#[cfg(test)] +mod tests; + use mcp_core::protocol::JsonRpcMessage; use serde_json::Value; use tokio::sync::mpsc; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs new file mode 100644 index 000000000000..957b11274279 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs @@ -0,0 +1,216 @@ +use super::{ + extract_failed_tasks, format_error_summary, format_failed_task_error, get_task_description, + handle_response, +}; +use crate::agents::sub_recipe_execution_tool::lib::{ + ExecutionResponse, ExecutionStats, TaskResult, TaskStatus, +}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus, error: Option) -> TaskResult { + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"partial_output": "test output"})), + error, + } +} + +fn create_test_execution_response( + results: Vec, + failed_count: usize, +) -> ExecutionResponse { + ExecutionResponse { + status: "completed".to_string(), + results: results.clone(), + stats: ExecutionStats { + total_tasks: results.len(), + completed: results.len() - failed_count, + failed: failed_count, + execution_time_ms: 1000, + }, + } +} + +#[test] +fn test_extract_failed_tasks() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result( + "task2", + TaskStatus::Failed, + Some("Error message".to_string()), + ), + create_test_task_result("task3", TaskStatus::Completed, None), + create_test_task_result( + "task4", + TaskStatus::Failed, + Some("Another error".to_string()), + ), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 2); + assert!(failed_tasks[0].contains("task2")); + assert!(failed_tasks[0].contains("Error message")); + assert!(failed_tasks[1].contains("task4")); + assert!(failed_tasks[1].contains("Another error")); +} + +#[test] +fn test_extract_failed_tasks_empty() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 0); +} + +#[test] +fn test_format_failed_task_error_with_error_message() { + let result = create_test_task_result( + "task1", + TaskStatus::Failed, + Some("Test error message".to_string()), + ); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task1")); + assert!(formatted.contains("Test error message")); + assert!(formatted.contains("test output")); + assert!(formatted.contains("ID: task1")); +} + +#[test] +fn test_format_failed_task_error_without_error_message() { + let result = create_test_task_result("task2", TaskStatus::Failed, None); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task2")); + assert!(formatted.contains("Unknown error")); + assert!(formatted.contains("test output")); +} + +#[test] +fn test_format_failed_task_error_empty_partial_output() { + let mut result = + create_test_task_result("task3", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({"partial_output": ""})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_partial_output() { + let mut result = + create_test_task_result("task4", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_data() { + let mut result = + create_test_task_result("task5", TaskStatus::Failed, Some("Error".to_string())); + result.data = None; + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_error_summary() { + let failed_tasks = vec![ + "Task 'task1': Error 1\nOutput: output1".to_string(), + "Task 'task2': Error 2\nOutput: output2".to_string(), + ]; + + let summary = format_error_summary(2, 5, failed_tasks); + + assert_eq!(summary, "2/5 tasks failed:\nTask 'task1': Error 1\nOutput: output1\nTask 'task2': Error 2\nOutput: output2"); +} + +#[test] +fn test_format_error_summary_single_failure() { + let failed_tasks = vec!["Task 'task1': Error\nOutput: output".to_string()]; + + let summary = format_error_summary(1, 3, failed_tasks); + + assert_eq!( + summary, + "1/3 tasks failed:\nTask 'task1': Error\nOutput: output" + ); +} + +#[test] +fn test_handle_response_success() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + let response = create_test_execution_response(results, 0); + + let result = handle_response(response); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value["status"], "completed"); + assert_eq!(value["stats"]["failed"], 0); +} + +#[test] +fn test_handle_response_with_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Failed, Some("Test error".to_string())), + ]; + let response = create_test_execution_response(results, 1); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("1/2 tasks failed")); + assert!(error.contains("task2")); + assert!(error.contains("Test error")); +} + +#[test] +fn test_handle_response_all_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed, Some("Error 1".to_string())), + create_test_task_result("task2", TaskStatus::Failed, Some("Error 2".to_string())), + ]; + let response = create_test_execution_response(results, 2); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("2/2 tasks failed")); + assert!(error.contains("task1")); + assert!(error.contains("task2")); + assert!(error.contains("Error 1")); + assert!(error.contains("Error 2")); +} + +#[test] +fn test_get_task_description() { + let result = create_test_task_result("test_task_123", TaskStatus::Completed, None); + + let description = get_task_description(&result); + + assert_eq!(description, "ID: test_task_123"); +} From 2c0f24c450f980aba4559522185f76c6a2ffe846 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 19:32:59 +1000 Subject: [PATCH 22/34] revert some unexpected deletion --- crates/goose-cli/src/session/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 5ad837991837..61e434568e10 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1015,6 +1015,9 @@ impl Session { } }; (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) + } else if let Some(Value::String(output)) = o.get("output") { + // Fallback for other MCP notification types + (output.to_owned(), None, None) } else if let Some(result) = format_task_execution_notification(data) { result } else { From 41fc7ca86e8c52c0ef86331d298c1411734db4e4 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 19:45:20 +1000 Subject: [PATCH 23/34] removed initial worker size --- .../sub_recipe_execute_task_tool.rs | 3 --- .../goose/src/agents/sub_recipe_execution_tool/types.rs | 8 -------- 2 files changed, 11 deletions(-) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 80a537393e5b..10c35e7b2e32 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -103,9 +103,6 @@ Pre-created Task Based: "properties": { "max_workers": { "type": "number" - }, - "initial_workers": { - "type": "number" } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index e558cee1f08b..ab0b496db469 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -128,30 +128,22 @@ impl SharedState { pub struct Config { #[serde(default = "default_max_workers")] pub max_workers: usize, - #[serde(default = "default_initial_workers")] - pub initial_workers: usize, } impl Default for Config { fn default() -> Self { Self { max_workers: default_max_workers(), - initial_workers: default_initial_workers(), } } } const DEFAULT_MAX_WORKERS: usize = 10; -const DEFAULT_INITIAL_WORKERS: usize = 2; fn default_max_workers() -> usize { DEFAULT_MAX_WORKERS } -fn default_initial_workers() -> usize { - DEFAULT_INITIAL_WORKERS -} - #[derive(Debug, Serialize)] pub struct ExecutionStats { pub total_tasks: usize, From 353c52a3f0848ff90bee083c2794eb57ed9f6078 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 19:55:43 +1000 Subject: [PATCH 24/34] removed worker config --- .claude/settings.local.json | 9 +++++++++ .../sub_recipe_execution_tool/executor/mod.rs | 7 +++---- .../sub_recipe_execution_tool/lib/mod.rs | 11 ++--------- .../sub_recipe_execute_task_tool.rs | 8 -------- .../agents/sub_recipe_execution_tool/types.rs | 19 ------------------- 5 files changed, 14 insertions(+), 40 deletions(-) create mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 000000000000..8a954032a156 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(grep:*)", + "Bash(awk:*)" + ], + "deny": [] + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs index c6183b7c137b..bb73b73e9c54 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -5,7 +5,7 @@ use tokio::sync::mpsc; use tokio::time::Instant; use crate::agents::sub_recipe_execution_tool::lib::{ - Config, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, + ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ DisplayMode, TaskExecutionTracker, @@ -17,6 +17,7 @@ use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; mod tests; const EXECUTION_STATUS_COMPLETED: &str = "completed"; +const DEFAULT_MAX_WORKERS: usize = 10; pub async fn execute_single_task( task: &Task, @@ -41,7 +42,6 @@ pub async fn execute_single_task( pub async fn execute_tasks_in_parallel( tasks: Vec, - config: Config, notifier: mpsc::Sender, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( @@ -67,8 +67,7 @@ pub async fn execute_tasks_in_parallel( let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); - // Simple static worker allocation - no dynamic scaling needed - let worker_count = std::cmp::min(task_count, config.max_workers); + let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); let mut worker_handles = Vec::new(); for i in 0..worker_count { let handle = spawn_worker(shared_state.clone(), i); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 0515bf37bb44..b6a0361322cb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -2,7 +2,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, + ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; @@ -22,13 +22,6 @@ pub async fn execute_tasks( serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) .map_err(|e| format!("Failed to parse tasks: {}", e))?; - let config: Config = if let Some(config_value) = input.get("config") { - serde_json::from_value(config_value.clone()) - .map_err(|e| format!("Failed to parse config: {}", e))? - } else { - Config::default() - }; - let task_count = tasks.len(); match execution_mode { ExecutionMode::Sequential => { @@ -41,7 +34,7 @@ pub async fn execute_tasks( } ExecutionMode::Parallel => { let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, config, notifier).await; + execute_tasks_in_parallel(tasks, notifier).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 10c35e7b2e32..7a37fd51c4e3 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -97,14 +97,6 @@ Pre-created Task Based: "required": ["id", "payload"] }, "description": "The tasks to run in parallel" - }, - "config": { - "type": "object", - "properties": { - "max_workers": { - "type": "number" - } - } } }, "required": ["tasks"] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index ab0b496db469..74cd8b6e6436 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -124,25 +124,6 @@ impl SharedState { } } -#[derive(Debug, Clone, Deserialize)] -pub struct Config { - #[serde(default = "default_max_workers")] - pub max_workers: usize, -} - -impl Default for Config { - fn default() -> Self { - Self { - max_workers: default_max_workers(), - } - } -} - -const DEFAULT_MAX_WORKERS: usize = 10; - -fn default_max_workers() -> usize { - DEFAULT_MAX_WORKERS -} #[derive(Debug, Serialize)] pub struct ExecutionStats { From 17a52138b5a9a4db042674032c4a5422a2ce5e52 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 9 Jul 2025 20:27:35 +1000 Subject: [PATCH 25/34] fixed fmt --- .claude/settings.local.json | 9 --------- .gitignore | 3 +++ .../src/agents/sub_recipe_execution_tool/lib/mod.rs | 6 ++---- .../goose/src/agents/sub_recipe_execution_tool/types.rs | 1 - 4 files changed, 5 insertions(+), 14 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 8a954032a156..000000000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(grep:*)", - "Bash(awk:*)" - ], - "deny": [] - } -} \ No newline at end of file diff --git a/.gitignore b/.gitignore index caab83d726c7..41f629f09366 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ ui/desktop/src/bin/goose_llm.dll # Hermit .hermit/ +# Claude +.claude + debug_*.txt # Docs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index b6a0361322cb..a4aedec568c7 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -2,8 +2,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; pub use crate::agents::sub_recipe_execution_tool::types::{ - ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, - TaskStatus, + ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; #[cfg(test)] @@ -33,8 +32,7 @@ pub async fn execute_tasks( } } ExecutionMode::Parallel => { - let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, notifier).await; + let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier).await; handle_response(response) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs index 74cd8b6e6436..ea31746032d7 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -124,7 +124,6 @@ impl SharedState { } } - #[derive(Debug, Serialize)] pub struct ExecutionStats { pub total_tasks: usize, From 4f52ab1e1a3942c038896a93985bc51fb69861b3 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 11 Jul 2025 16:58:19 +1000 Subject: [PATCH 26/34] renamed types.rs to task_types.rs --- crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs | 2 +- crates/goose/src/agents/sub_recipe_execution_tool/mod.rs | 2 +- .../agents/sub_recipe_execution_tool/notification_events.rs | 2 +- .../sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs | 4 ++-- .../sub_recipe_execution_tool/task_execution_tracker.rs | 4 +++- .../sub_recipe_execution_tool/{types.rs => task_types.rs} | 0 crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs | 2 +- .../goose/src/agents/sub_recipe_execution_tool/utils/mod.rs | 2 +- .../goose/src/agents/sub_recipe_execution_tool/utils/tests.rs | 2 +- crates/goose/src/agents/sub_recipe_execution_tool/workers.rs | 2 +- 10 files changed, 12 insertions(+), 10 deletions(-) rename crates/goose/src/agents/sub_recipe_execution_tool/{types.rs => task_types.rs} (100%) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index a4aedec568c7..746e4ecb151c 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -1,7 +1,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ execute_single_task, execute_tasks_in_parallel, }; -pub use crate::agents::sub_recipe_execution_tool::types::{ +pub use crate::agents::sub_recipe_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index 267caab8f115..b6363ba20d14 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -3,7 +3,7 @@ pub mod lib; pub mod notification_events; pub mod sub_recipe_execute_task_tool; mod task_execution_tracker; +mod task_types; mod tasks; -mod types; pub mod utils; mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs index 97a4576b2d99..2a6134ea1a55 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs @@ -1,4 +1,4 @@ -use crate::agents::sub_recipe_execution_tool::types::TaskStatus; +use crate::agents::sub_recipe_execution_tool::task_types::TaskStatus; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 7a37fd51c4e3..62054cca18e6 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -2,8 +2,8 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, sub_recipe_execution_tool::types::ExecutionMode, - tool_execution::ToolCallResult, + sub_recipe_execution_tool::lib::execute_tasks, + sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index 639e4580b843..b456fd77424f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -9,7 +9,9 @@ use crate::agents::sub_recipe_execution_tool::notification_events::{ FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, TaskInfo as EventTaskInfo, }; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskResult, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{ + Task, TaskInfo, TaskResult, TaskStatus, +}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::Value; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs similarity index 100% rename from crates/goose/src/agents/sub_recipe_execution_tool/types.rs rename to crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 60d022816680..fb41b0632d18 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -7,7 +7,7 @@ use tokio::process::Command; use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs index b86a69a8fcfe..1ead865e6571 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::agents::sub_recipe_execution_tool::types::{TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{TaskInfo, TaskStatus}; pub fn get_task_name(task_info: &TaskInfo) -> &str { task_info diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index 618026a7f003..f799b699aaca 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -1,4 +1,4 @@ -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskInfo, TaskStatus}; use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::json; use std::collections::HashMap; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index 35e9f6d22219..fefbf0eb82b2 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,5 +1,5 @@ +use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::types::{SharedState, Task}; use std::sync::Arc; async fn receive_task(state: &SharedState) -> Option { From 0e016303a7a00e120b7942f8ec413328e18dde9f Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 15 Jul 2025 09:47:14 -0700 Subject: [PATCH 27/34] feat: dynamic tasks (#3414) --- crates/goose-cli/src/session/builder.rs | 2 + crates/goose/src/agents/agent.rs | 8 +- .../agents/recipe_tools/dynamic_task_tools.rs | 138 ++++++++++++++++++ crates/goose/src/agents/recipe_tools/mod.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 4 +- 5 files changed, 149 insertions(+), 4 deletions(-) create mode 100644 crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index b5d3da1d7f45..0377a36e905e 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -187,6 +187,8 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); + + // Sub-recipes if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 38a488a9781a..b9592bff808a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,6 +9,9 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; +use crate::agents::recipe_tools::dynamic_task_tools::{ + create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, +}; use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; @@ -53,7 +56,6 @@ use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; use super::subagent_manager::SubAgentManager; -use super::subagent_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -295,6 +297,8 @@ impl Agent { .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { + create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( @@ -559,7 +563,7 @@ impl Agent { // Add subagent tool (only if ALPHA_FEATURES is enabled) let config = Config::global(); if config.get_param::("ALPHA_FEATURES").unwrap_or(false) { - prefixed_tools.push(subagent_tools::run_task_subagent_tool()); + prefixed_tools.push(create_dynamic_task_tool()); } // Add resource tools if supported diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs new file mode 100644 index 000000000000..77a66bee3fac --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -0,0 +1,138 @@ +// ======================================= +// Module: Dynamic Task Tools +// Handles creation of tasks dynamically without sub-recipes +// ======================================= +use crate::agents::recipe_tools::sub_recipe_tools::{ + EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, +}; +use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::tool_execution::ToolCallResult; +use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; +use serde_json::{json, Value}; + +pub const DYNAMIC_TASK_TOOL_NAME_PREFIX: &str = "dynamic_task__create_task"; + +pub fn create_dynamic_task_tool() -> Tool { + Tool::new( + format!("{}", DYNAMIC_TASK_TOOL_NAME_PREFIX), + format!( + "Creates a dynamic task object(s) based on textual instructions. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same text instruction but with different parameter values. \ + This is useful when you need to execute the same instruction multiple times with varying inputs. \ + After creating the task list, pass it to the task executor to run all tasks." + ), + json!({ + "type": "object", + "properties": { + "task_parameters": { + "type": "array", + "description": "Array of parameter sets for creating tasks. \ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": { + "text_instruction": { + "type": "string", + "description": "The text instruction to execute" + }, + "timeout_seconds": { + "type": "integer", + "description": "Optional timeout for the task in seconds (default: 300)", + "minimum": 1 + } + }, + "required": ["text_instruction"] + } + } + } + }), + Some(ToolAnnotations { + title: Some(format!("Dynamic Task Creation")), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + +fn extract_task_parameters(params: &Value) -> Vec { + params + .get("task_parameters") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default() +} + +fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec { + task_params + .iter() + .map(|task_param| { + let text_instruction = task_param + .get("text_instruction") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let timeout_seconds = task_param + .get("timeout_seconds") + .and_then(|v| v.as_u64()) + .unwrap_or(300); + + let payload = json!({ + "text_instruction": text_instruction + }); + + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "text_instruction".to_string(), + timeout_in_seconds: Some(timeout_seconds), + payload, + } + }) + .collect() +} + +fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { + json!({ + "tasks": tasks, + "execution_mode": execution_mode + }) +} + +pub async fn create_dynamic_task(params: Value) -> ToolCallResult { + let task_params_array = extract_task_parameters(¶ms); + + if task_params_array.is_empty() { + return ToolCallResult::from(Err(ToolError::ExecutionError( + "No task parameters provided".to_string(), + ))); + } + + let tasks = create_text_instruction_tasks_from_params(&task_params_array); + + // Use parallel execution if there are multiple tasks, sequential for single task + let execution_mode = if tasks.len() > 1 { + EXECUTION_MODE_PARALLEL + } else { + EXECUTION_MODE_SEQUENTIAL + }; + + let task_execution_payload = create_task_execution_payload(tasks, execution_mode); + + let tasks_json = match serde_json::to_string(&task_execution_payload) { + Ok(json) => json, + Err(e) => { + return ToolCallResult::from(Err(ToolError::ExecutionError(format!( + "Failed to serialize task list: {}", + e + )))) + } + }; + ToolCallResult::from(Ok(vec![Content::text(tasks_json)])) +} diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 90603c88488e..6e6f28a80310 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1,2 +1,3 @@ +pub mod dynamic_task_tools; pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index fd0edaeaecd1..bc0347dfd509 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -11,8 +11,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -const EXECUTION_MODE_PARALLEL: &str = "parallel"; -const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; +pub const EXECUTION_MODE_PARALLEL: &str = "parallel"; +pub const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); From adf378a6e355fd83bdfb6dd211359e525dc12e5b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 10:24:58 +1000 Subject: [PATCH 28/34] Revert "feat: dynamic tasks (#3414)" This reverts commit 0e016303a7a00e120b7942f8ec413328e18dde9f. --- crates/goose-cli/src/session/builder.rs | 2 - crates/goose/src/agents/agent.rs | 8 +- .../agents/recipe_tools/dynamic_task_tools.rs | 138 ------------------ crates/goose/src/agents/recipe_tools/mod.rs | 1 - .../agents/recipe_tools/sub_recipe_tools.rs | 4 +- 5 files changed, 4 insertions(+), 149 deletions(-) delete mode 100644 crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 0377a36e905e..b5d3da1d7f45 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -187,8 +187,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); - - // Sub-recipes if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index b9592bff808a..38a488a9781a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,9 +9,6 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; -use crate::agents::recipe_tools::dynamic_task_tools::{ - create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, -}; use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; @@ -56,6 +53,7 @@ use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; use super::subagent_manager::SubAgentManager; +use super::subagent_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -297,8 +295,6 @@ impl Agent { .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await - } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { - create_dynamic_task(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( @@ -563,7 +559,7 @@ impl Agent { // Add subagent tool (only if ALPHA_FEATURES is enabled) let config = Config::global(); if config.get_param::("ALPHA_FEATURES").unwrap_or(false) { - prefixed_tools.push(create_dynamic_task_tool()); + prefixed_tools.push(subagent_tools::run_task_subagent_tool()); } // Add resource tools if supported diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs deleted file mode 100644 index 77a66bee3fac..000000000000 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ /dev/null @@ -1,138 +0,0 @@ -// ======================================= -// Module: Dynamic Task Tools -// Handles creation of tasks dynamically without sub-recipes -// ======================================= -use crate::agents::recipe_tools::sub_recipe_tools::{ - EXECUTION_MODE_PARALLEL, EXECUTION_MODE_SEQUENTIAL, -}; -use crate::agents::sub_recipe_execution_tool::lib::Task; -use crate::agents::tool_execution::ToolCallResult; -use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; -use serde_json::{json, Value}; - -pub const DYNAMIC_TASK_TOOL_NAME_PREFIX: &str = "dynamic_task__create_task"; - -pub fn create_dynamic_task_tool() -> Tool { - Tool::new( - format!("{}", DYNAMIC_TASK_TOOL_NAME_PREFIX), - format!( - "Creates a dynamic task object(s) based on textual instructions. \ - Provide an array of parameter sets in the 'task_parameters' field:\n\ - - For a single task: provide an array with one parameter set\n\ - - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ - Each task will run the same text instruction but with different parameter values. \ - This is useful when you need to execute the same instruction multiple times with varying inputs. \ - After creating the task list, pass it to the task executor to run all tasks." - ), - json!({ - "type": "object", - "properties": { - "task_parameters": { - "type": "array", - "description": "Array of parameter sets for creating tasks. \ - For a single task, provide an array with one element. \ - For multiple tasks, provide an array with multiple elements, each with different parameter values. \ - If there is no parameter set, provide an empty array.", - "items": { - "type": "object", - "properties": { - "text_instruction": { - "type": "string", - "description": "The text instruction to execute" - }, - "timeout_seconds": { - "type": "integer", - "description": "Optional timeout for the task in seconds (default: 300)", - "minimum": 1 - } - }, - "required": ["text_instruction"] - } - } - } - }), - Some(ToolAnnotations { - title: Some(format!("Dynamic Task Creation")), - read_only_hint: false, - destructive_hint: true, - idempotent_hint: false, - open_world_hint: true, - }), - ) -} - -fn extract_task_parameters(params: &Value) -> Vec { - params - .get("task_parameters") - .and_then(|v| v.as_array()) - .cloned() - .unwrap_or_default() -} - -fn create_text_instruction_tasks_from_params(task_params: &[Value]) -> Vec { - task_params - .iter() - .map(|task_param| { - let text_instruction = task_param - .get("text_instruction") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let timeout_seconds = task_param - .get("timeout_seconds") - .and_then(|v| v.as_u64()) - .unwrap_or(300); - - let payload = json!({ - "text_instruction": text_instruction - }); - - Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "text_instruction".to_string(), - timeout_in_seconds: Some(timeout_seconds), - payload, - } - }) - .collect() -} - -fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { - json!({ - "tasks": tasks, - "execution_mode": execution_mode - }) -} - -pub async fn create_dynamic_task(params: Value) -> ToolCallResult { - let task_params_array = extract_task_parameters(¶ms); - - if task_params_array.is_empty() { - return ToolCallResult::from(Err(ToolError::ExecutionError( - "No task parameters provided".to_string(), - ))); - } - - let tasks = create_text_instruction_tasks_from_params(&task_params_array); - - // Use parallel execution if there are multiple tasks, sequential for single task - let execution_mode = if tasks.len() > 1 { - EXECUTION_MODE_PARALLEL - } else { - EXECUTION_MODE_SEQUENTIAL - }; - - let task_execution_payload = create_task_execution_payload(tasks, execution_mode); - - let tasks_json = match serde_json::to_string(&task_execution_payload) { - Ok(json) => json, - Err(e) => { - return ToolCallResult::from(Err(ToolError::ExecutionError(format!( - "Failed to serialize task list: {}", - e - )))) - } - }; - ToolCallResult::from(Ok(vec![Content::text(tasks_json)])) -} diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 6e6f28a80310..90603c88488e 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1,3 +1,2 @@ -pub mod dynamic_task_tools; pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index bc0347dfd509..fd0edaeaecd1 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -11,8 +11,8 @@ use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubReci use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -pub const EXECUTION_MODE_PARALLEL: &str = "parallel"; -pub const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; +const EXECUTION_MODE_PARALLEL: &str = "parallel"; +const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); From 7dd8587286f210d261ba3381be6e85a322130e20 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 11:49:32 +1000 Subject: [PATCH 29/34] remove runs and timeout_in_seconds --- Cargo.lock | 4 +- .../goose-cli/src/recipes/extract_from_cli.rs | 1 - .../agents/recipe_tools/param_utils/mod.rs | 103 ++------------- .../agents/recipe_tools/param_utils/tests.rs | 120 +++--------------- .../agents/recipe_tools/sub_recipe_tools.rs | 22 +--- .../recipe_tools/sub_recipe_tools/tests.rs | 22 +--- .../sub_recipe_execute_task_tool.rs | 4 - .../sub_recipe_execution_tool/task_types.rs | 1 - .../agents/sub_recipe_execution_tool/tasks.rs | 37 +----- .../sub_recipe_execution_tool/utils/tests.rs | 5 - crates/goose/src/recipe/mod.rs | 8 -- 11 files changed, 41 insertions(+), 286 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9541e0b7dbf9..d07623aede26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8604,9 +8604,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 8ba8658c8cb6..84d578c0cae4 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -31,7 +31,6 @@ pub fn extract_recipe_info_from_cli( let additional_sub_recipe = SubRecipe { path: recipe_file_path.to_string_lossy().to_string(), name, - timeout_in_seconds: None, values: None, executions: None, }; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs index e7891c1bb5e7..bd8468c032dd 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -4,113 +4,34 @@ use std::collections::HashMap; use crate::recipe::SubRecipe; -pub fn extract_run_params( +pub fn prepare_command_params( sub_recipe: &SubRecipe, -) -> (HashMap, Vec>) { + params_from_tool_call: Vec, +) -> Result>> { let base_params = sub_recipe.values.clone().unwrap_or_default(); - let run_params = sub_recipe - .executions - .as_ref() - .and_then(|e| e.runs.as_ref()) - .map(|runs| { - runs.iter() - .map(|run| { - let mut params = base_params.clone(); - if let Some(run_values) = &run.values { - params.extend(run_values.clone()); - } - params - }) - .collect::>() - }) - .unwrap_or_default(); - (base_params, run_params) -} - -pub fn validate_param_counts( - run_params: &[HashMap], - params_from_tool_call: &[Value], -) -> Result<()> { - let has_run_params = !run_params.is_empty(); - let multiple_params_from_tool_call = params_from_tool_call.len() > 1; - let count_mismatch = run_params.len() != params_from_tool_call.len(); - - if has_run_params && multiple_params_from_tool_call && count_mismatch { - return Err(anyhow::anyhow!( - "The number of runs in the sub recipe ({}) does not match the number of task parameters ({})", - run_params.len(), - params_from_tool_call.len() - )); - } - Ok(()) -} - -pub fn prepare_base_params( - base_params: &HashMap, - run_params: &[HashMap], - params_from_tool_call: &[Value], -) -> Vec> { - if run_params.is_empty() { - vec![base_params.clone(); params_from_tool_call.len()] - } else { - run_params.to_vec() - } -} - -pub fn prepare_tool_params( - params_from_tool_call: &[Value], - run_params: &[HashMap], -) -> Vec { - if params_from_tool_call.len() == 1 && run_params.len() > 1 { - vec![params_from_tool_call[0].clone(); run_params.len()] - } else { - params_from_tool_call.to_vec() + if params_from_tool_call.is_empty() { + return Ok(vec![base_params]); } -} -pub fn merge_parameters( - tool_params: Vec, - base_params: Vec>, -) -> Vec> { - tool_params + let result = params_from_tool_call .into_iter() - .zip(base_params) - .map(|(tool_param, mut base_param_map)| { + .map(|tool_param| { + let mut param_map = base_params.clone(); if let Some(param_obj) = tool_param.as_object() { for (key, value) in param_obj { let value_str = value .as_str() .map(String::from) .unwrap_or_else(|| value.to_string()); - base_param_map.entry(key.clone()).or_insert(value_str); + param_map.entry(key.clone()).or_insert(value_str); } } - base_param_map + param_map }) - .collect() -} - -pub fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Vec, -) -> Result>> { - let (base_params, run_params) = extract_run_params(sub_recipe); - - if params_from_tool_call.is_empty() { - return Ok(run_params); - } - - validate_param_counts(&run_params, ¶ms_from_tool_call)?; - - let base_params_for_merging = - prepare_base_params(&base_params, &run_params, ¶ms_from_tool_call); - let tool_params_for_merging = prepare_tool_params(¶ms_from_tool_call, &run_params); + .collect(); - Ok(merge_parameters( - tool_params_for_merging, - base_params_for_merging, - )) + Ok(result) } #[cfg(test)] diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index 9cbec0c7b828..d6a0182c091b 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::recipe::{Execution, ExecutionRun, SubRecipe}; +use crate::recipe::SubRecipe; use serde_json::json; use crate::agents::recipe_tools::param_utils::prepare_command_params; @@ -9,26 +9,12 @@ fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), - timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), executions: None, }; sub_recipe } -fn create_execution_values(key: &str, values: Vec) -> Execution { - let runs = values - .iter() - .map(|value| ExecutionRun { - values: Some(HashMap::from([(key.to_string(), value.to_string())])), - }) - .collect(); - Execution { - parallel: true, - runs: Some(runs), - } -} - mod prepare_command_params_tests { use super::*; @@ -83,35 +69,13 @@ mod prepare_command_params_tests { sub_recipe.values = None; let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!(result.len(), 0); + assert_eq!(result, vec![HashMap::new()]); } } - mod with_execution_runs { + mod multiple_tool_parameters { use super::*; - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = - Some(create_execution_values("key2", vec!["value2".to_string()])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ("key3".to_string(), "value3".to_string()) - ]),], - result - ); - } - #[test] fn test_return_command_param_when_all_values_from_tool_call_parameters() { let parameter_array = vec![ @@ -126,7 +90,6 @@ mod prepare_command_params_tests { ]; let mut sub_recipe = setup_default_sub_recipe(); sub_recipe.values = None; - sub_recipe.executions = None; let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); assert_eq!( @@ -145,84 +108,31 @@ mod prepare_command_params_tests { } #[test] - fn test_return_command_param_when_all_from_values_in_sub_recipe() { - let parameter_array = vec![]; + fn test_merge_base_values_with_tool_parameters() { + let parameter_array = vec![ + json!(HashMap::from([("key2".to_string(), "override_value1".to_string())])), + json!(HashMap::from([("key2".to_string(), "override_value2".to_string())])), + ]; let mut sub_recipe = setup_default_sub_recipe(); sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key3".to_string(), "value3".to_string()), + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), ])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); assert_eq!( vec![ HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), ]), HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) - ], - result - ); - } - - #[test] - fn test_return_command_param_when_tool_call_parameters_has_one_item_and_execution_runs_has_multiple_items( - ) { - let parameter_array = vec![json!(HashMap::from([( - "key3".to_string(), - "value3".to_string() - ),]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string(), "key2_value2".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![ - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value1".to_string()), - ("key3".to_string(), "value3".to_string()), + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), ]), - HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "key2_value2".to_string()), - ("key3".to_string(), "value3".to_string()), - ]) ], result ); } - - #[test] - fn test_throw_error_when_execution_runs_value_length_not_match_with_tool_call_parameters() { - let parameter_array = vec![ - json!(HashMap::from([("key3".to_string(), "value3".to_string())])), - json!(HashMap::from([("key4".to_string(), "value4".to_string())])), - ]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value1".to_string()], - )); - - let result = prepare_command_params(&sub_recipe, parameter_array); - - assert!(result.is_err()); - } } -} +} \ No newline at end of file diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index fd0edaeaecd1..a9beb1355266 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -5,14 +5,12 @@ use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; -use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::sub_recipe_execution_tool::lib::{ExecutionMode, Task}; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; use super::param_utils::prepare_command_params; pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -const EXECUTION_MODE_PARALLEL: &str = "parallel"; -const EXECUTION_MODE_SEQUENTIAL: &str = "sequential"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); @@ -64,14 +62,13 @@ fn create_tasks_from_params( Task { id: uuid::Uuid::new_v4().to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: sub_recipe.timeout_in_seconds, payload, } }) .collect() } -fn get_execution_mode(sub_recipe: &SubRecipe) -> &'static str { +fn get_execution_mode(sub_recipe: &SubRecipe) -> ExecutionMode { let is_parallel = sub_recipe .executions .as_ref() @@ -79,13 +76,13 @@ fn get_execution_mode(sub_recipe: &SubRecipe) -> &'static str { .unwrap_or(false); if is_parallel { - EXECUTION_MODE_PARALLEL + ExecutionMode::Parallel } else { - EXECUTION_MODE_SEQUENTIAL + ExecutionMode::Sequential } } -fn create_task_execution_payload(tasks: Vec, execution_mode: &str) -> Value { +fn create_task_execution_payload(tasks: Vec, execution_mode: ExecutionMode) -> Value { json!({ "tasks": tasks, "execution_mode": execution_mode @@ -120,15 +117,6 @@ fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet { sub_recipe_params_with_values.insert(param_name.clone()); } } - if let Some(runs) = sub_recipe.executions.as_ref().and_then(|e| e.runs.as_ref()) { - for run in runs { - if let Some(params_with_value) = &run.values { - for param_name in params_with_value.keys() { - sub_recipe_params_with_values.insert(param_name.clone()); - } - } - } - } sub_recipe_params_with_values } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index ca66e97819bb..21b4d1710fd2 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -2,7 +2,7 @@ mod tests { use std::collections::HashMap; - use crate::recipe::{Execution, ExecutionRun, SubRecipe}; + use crate::recipe::{Execution, SubRecipe}; use serde_json::json; use serde_json::Value; use tempfile::TempDir; @@ -11,23 +11,15 @@ mod tests { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), - timeout_in_seconds: None, values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), executions: None, }; sub_recipe } - fn create_execution_values(key: &str, values: Vec) -> Execution { - let runs = values - .iter() - .map(|value| ExecutionRun { - values: Some(HashMap::from([(key.to_string(), value.to_string())])), - }) - .collect(); + fn create_execution() -> Execution { Execution { parallel: true, - runs: Some(runs), } } @@ -183,10 +175,7 @@ mod tests { prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value_1".to_string(), "key2_value_2".to_string()], - )); + sub_recipe.executions = Some(create_execution()); let result = get_input_schema(&sub_recipe).unwrap(); @@ -210,10 +199,7 @@ mod tests { ("key1".to_string(), "value1".to_string()), ("key3".to_string(), "value3".to_string()), ])); - sub_recipe.executions = Some(create_execution_values( - "key2", - vec!["key2_value_1".to_string(), "key2_value_2".to_string()], - )); + sub_recipe.executions = Some(create_execution()); let result = get_input_schema(&sub_recipe).unwrap(); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 62054cca18e6..40e57f2bd6a5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -62,10 +62,6 @@ Pre-created Task Based: "default": "sub_recipe", "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" }, - "timeout_in_seconds": { - "type": "number", - "description": "timeout in seconds for the task." - }, "payload": { "type": "object", "properties": { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs index ea31746032d7..b1385f375cc7 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs @@ -18,7 +18,6 @@ pub enum ExecutionMode { pub struct Task { pub id: String, pub task_type: String, - pub timeout_in_seconds: Option, pub payload: Value, } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index fb41b0632d18..a3ab4140de9d 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -1,60 +1,29 @@ use serde_json::Value; use std::process::Stdio; use std::sync::Arc; -use std::time::Duration; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use tokio::time::timeout; use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; -const DEFAULT_TASK_TIMEOUT_SECONDS: u64 = 300; - pub async fn process_task( task: &Task, task_execution_tracker: Arc, ) -> TaskResult { - let timeout_in_seconds = task - .timeout_in_seconds - .unwrap_or(DEFAULT_TASK_TIMEOUT_SECONDS); - let task_clone = task.clone(); - let timeout_duration = Duration::from_secs(timeout_in_seconds); - - let task_execution_tracker_clone = task_execution_tracker.clone(); - match timeout( - timeout_duration, - get_task_result(task_clone, task_execution_tracker), - ) - .await - { - Ok(Ok(data)) => TaskResult { + match get_task_result(task.clone(), task_execution_tracker).await { + Ok(data) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, data: Some(data), error: None, }, - Ok(Err(error)) => TaskResult { + Err(error) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Failed, data: None, error: Some(error), }, - Err(_) => { - let current_output = task_execution_tracker_clone - .get_current_output(&task.id) - .await - .unwrap_or_default(); - - TaskResult { - task_id: task.id.clone(), - status: TaskStatus::Failed, - data: Some(serde_json::json!({ - "partial_output": current_output - })), - error: Some(format!("Task timed out after {}s", timeout_in_seconds)), - } - } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs index f799b699aaca..de5bac92fcd8 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -22,7 +22,6 @@ mod test_get_task_name { let sub_recipe_task = Task { id: "task_1".to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: None, payload: json!({ "sub_recipe": { "name": "my_recipe", @@ -40,7 +39,6 @@ mod test_get_task_name { fn falls_back_to_task_id_for_text_instruction() { let text_task = Task { id: "task_2".to_string(), - timeout_in_seconds: None, task_type: "text_instruction".to_string(), payload: json!({"text_instruction": "do something"}), }; @@ -55,7 +53,6 @@ mod test_get_task_name { let malformed_task = Task { id: "task_3".to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: None, payload: json!({ "sub_recipe": { "recipe_path": "/path/to/recipe" @@ -74,7 +71,6 @@ mod test_get_task_name { let malformed_task = Task { id: "task_4".to_string(), task_type: "sub_recipe".to_string(), - timeout_in_seconds: None, payload: json!({}), // missing "sub_recipe" field }; @@ -91,7 +87,6 @@ mod count_by_status { let task = Task { id: id.to_string(), task_type: "test".to_string(), - timeout_in_seconds: None, payload: json!({}), }; create_task_info_with_defaults(task, status) diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 525433221ef2..98ecbbbe9cd7 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -135,7 +135,6 @@ pub struct Response { pub struct SubRecipe { pub name: String, pub path: String, - pub timeout_in_seconds: Option, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -145,15 +144,8 @@ pub struct SubRecipe { pub struct Execution { #[serde(default)] pub parallel: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub runs: Option>, } -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ExecutionRun { - #[serde(default, deserialize_with = "deserialize_value_map_as_string")] - pub values: Option>, -} fn deserialize_value_map_as_string<'de, D>( deserializer: D, From 64c6240283be01ff6626fe3cf1fb3ae28a5b9ee7 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 13:17:07 +1000 Subject: [PATCH 30/34] renamed variable --- crates/goose-cli/src/session/mod.rs | 4 ++-- test_parallel_execution.yaml | 18 ++++++++++++++++++ weather_sub.yaml | 12 ++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 test_parallel_execution.yaml create mode 100644 weather_sub.yaml diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 21fb5bd576ee..d051b9b85e06 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1009,7 +1009,7 @@ impl Session { match method.as_str() { "notifications/message" => { let data = o.get("data").unwrap_or(&Value::Null); - let (formatted_message, subagent_id, _notification_type) = match data { + let (formatted_message, subagent_id, message_notification_type) = match data { Value::String(s) => (s.clone(), None, None), Value::Object(o) => { // Check for subagent notification structure first @@ -1080,7 +1080,7 @@ impl Session { } else { progress_bars.log(&formatted_message); } - } else if let Some(ref notification_type) = _notification_type { + } else if let Some(ref notification_type) = message_notification_type { if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { if interactive { let _ = progress_bars.hide(); diff --git a/test_parallel_execution.yaml b/test_parallel_execution.yaml new file mode 100644 index 000000000000..4a63b2ff5f2d --- /dev/null +++ b/test_parallel_execution.yaml @@ -0,0 +1,18 @@ +version: "1.0.0" +title: "Test Parallel Execution" +description: "Test that executions.parallel=true is respected even with sequential prompt" + +parameters: + - key: city + input_type: string + requirement: required + description: "City name" + +sub_recipes: + - name: weather + path: "./weather_sub.yaml" + executions: + parallel: true + +prompt: | + Get weather for Sydney, Melbourne, and Brisbane in sequential order \ No newline at end of file diff --git a/weather_sub.yaml b/weather_sub.yaml new file mode 100644 index 000000000000..c39c3d377821 --- /dev/null +++ b/weather_sub.yaml @@ -0,0 +1,12 @@ +version: "1.0.0" +title: "Weather Sub-Recipe" +description: "Get weather for a city" + +parameters: + - key: city + input_type: string + requirement: required + description: "City name" + +prompt: | + Get weather information for {{ city }} \ No newline at end of file From 3c7238c12f2fe0a15c78e9c8296d01645ddba532 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 15:00:43 +1000 Subject: [PATCH 31/34] used task ids instead task payload as input schema for executor tool --- crates/goose/src/agents/agent.rs | 15 +- .../agents/recipe_tools/param_utils/tests.rs | 12 +- .../agents/recipe_tools/sub_recipe_tools.rs | 23 +- .../recipe_tools/sub_recipe_tools/tests.rs | 223 +++++------------- .../sub_recipe_execution_tool/lib/mod.rs | 25 +- .../agents/sub_recipe_execution_tool/mod.rs | 1 + .../sub_recipe_execute_task_tool.rs | 65 ++--- .../tasks_manager.rs | 86 +++++++ crates/goose/src/agents/sub_recipe_manager.rs | 11 +- crates/goose/src/recipe/mod.rs | 1 - test_parallel_execution.yaml | 18 -- weather_sub.yaml | 12 - 12 files changed, 229 insertions(+), 263 deletions(-) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs delete mode 100644 test_parallel_execution.yaml delete mode 100644 weather_sub.yaml diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 3adee02309a0..cb99da418eff 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -12,6 +12,7 @@ use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; +use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::{push_message, Message}; @@ -63,6 +64,7 @@ pub struct Agent { pub(super) provider: Mutex>>, pub(super) extension_manager: RwLock, pub(super) sub_recipe_manager: Mutex, + pub(super) tasks_manager: TasksManager, pub(super) final_output_tool: Mutex>, pub(super) frontend_tools: Mutex>, pub(super) frontend_instructions: Mutex>, @@ -137,6 +139,7 @@ impl Agent { provider: Mutex::new(None), extension_manager: RwLock::new(ExtensionManager::new()), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), + tasks_manager: TasksManager::new(), final_output_tool: Mutex::new(None), frontend_tools: Mutex::new(HashMap::new()), frontend_instructions: Mutex::new(None), @@ -291,10 +294,18 @@ impl Agent { let sub_recipe_manager = self.sub_recipe_manager.lock().await; let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) { sub_recipe_manager - .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) + .dispatch_sub_recipe_tool_call( + &tool_call.name, + tool_call.arguments.clone(), + &self.tasks_manager, + ) .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { - sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + sub_recipe_execute_task_tool::run_tasks( + tool_call.arguments.clone(), + &self.tasks_manager, + ) + .await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index d6a0182c091b..04687c106f5c 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -110,8 +110,14 @@ mod prepare_command_params_tests { #[test] fn test_merge_base_values_with_tool_parameters() { let parameter_array = vec![ - json!(HashMap::from([("key2".to_string(), "override_value1".to_string())])), - json!(HashMap::from([("key2".to_string(), "override_value2".to_string())])), + json!(HashMap::from([( + "key2".to_string(), + "override_value1".to_string() + )])), + json!(HashMap::from([( + "key2".to_string(), + "override_value2".to_string() + )])), ]; let mut sub_recipe = setup_default_sub_recipe(); sub_recipe.values = Some(HashMap::from([ @@ -135,4 +141,4 @@ mod prepare_command_params_tests { ); } } -} \ No newline at end of file +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index a9beb1355266..809532c2e994 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -6,6 +6,7 @@ use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; use crate::agents::sub_recipe_execution_tool::lib::{ExecutionMode, Task}; +use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; use super::param_utils::prepare_command_params; @@ -23,7 +24,7 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ Each task will run the same sub recipe but with different parameter values. \ This is useful when you need to execute the same sub recipe multiple times with varying inputs. \ - After creating the task list, pass it to the task executor to run all tasks.", + After creating the task list with execution_mode, pass them to the task executor to run all tasks.", sub_recipe.name ), input_schema, @@ -49,7 +50,7 @@ fn create_tasks_from_params( sub_recipe: &SubRecipe, command_params: &[std::collections::HashMap], ) -> Vec { - command_params + let tasks: Vec = command_params .iter() .map(|task_command_param| { let payload = json!({ @@ -65,7 +66,9 @@ fn create_tasks_from_params( payload, } }) - .collect() + .collect(); + + tasks } fn get_execution_mode(sub_recipe: &SubRecipe) -> ExecutionMode { @@ -82,22 +85,28 @@ fn get_execution_mode(sub_recipe: &SubRecipe) -> ExecutionMode { } } -fn create_task_execution_payload(tasks: Vec, execution_mode: ExecutionMode) -> Value { +fn create_task_execution_payload(tasks: &[Task], execution_mode: ExecutionMode) -> Value { + let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); json!({ - "tasks": tasks, + "task_ids": task_ids, "execution_mode": execution_mode }) } -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { +pub async fn create_sub_recipe_task( + sub_recipe: &SubRecipe, + params: Value, + tasks_manager: &TasksManager, +) -> Result { let task_params_array = extract_task_parameters(¶ms); let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; let tasks = create_tasks_from_params(sub_recipe, &command_params); let execution_mode = get_execution_mode(sub_recipe); - let task_execution_payload = create_task_execution_payload(tasks, execution_mode); + let task_execution_payload = create_task_execution_payload(&tasks, execution_mode); let tasks_json = serde_json::to_string(&task_execution_payload) .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + tasks_manager.save_tasks(tasks.clone()).await; Ok(tasks_json) } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 21b4d1710fd2..69f347833ee0 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -2,7 +2,7 @@ mod tests { use std::collections::HashMap; - use crate::recipe::{Execution, SubRecipe}; + use crate::recipe::SubRecipe; use serde_json::json; use serde_json::Value; use tempfile::TempDir; @@ -17,12 +17,6 @@ mod tests { sub_recipe } - fn create_execution() -> Execution { - Execution { - parallel: true, - } - } - mod get_input_schema { use super::*; use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; @@ -50,10 +44,7 @@ mod tests { assert_eq!(&expected_task_parameters_items, task_parameters_items); } - mod without_execution_runs { - use super::*; - - const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -74,167 +65,67 @@ mod tests { ] }"#; - #[test] - fn test_with_one_param_in_tool_input() { - let (mut sub_recipe, _temp_dir) = - prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - - let result = get_input_schema(&sub_recipe).unwrap(); - - verify_task_parameters( - result, - json!({ - "type": "object", - "properties": { - "key2": { "type": "number", "description": "An optional parameter" } - }, - "required": [] - }), - ); - } + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - #[test] - fn test_without_param_in_tool_input() { - let (mut sub_recipe, _temp_dir) = - prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ])); + let result = get_input_schema(&sub_recipe).unwrap(); - let result = get_input_schema(&sub_recipe).unwrap(); - - assert_eq!( - None, - result - .get("properties") - .unwrap() - .as_object() - .unwrap() - .get("task_parameters") - ); - } - - #[test] - fn test_with_all_params_in_tool_input() { - let (mut sub_recipe, _temp_dir) = - prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); - sub_recipe.values = None; - - let result = get_input_schema(&sub_recipe).unwrap(); - - verify_task_parameters( - result, - json!({ - "type": "object", - "properties": { - "key1": { "type": "string", "description": "A test parameter" }, - "key2": { "type": "number", "description": "An optional parameter" } - }, - "required": ["key1"] - }), - ); - } + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": [] + }), + ); } - mod execution_runs { - use super::*; + #[test] + fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = get_input_schema(&sub_recipe).unwrap(); + + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } - const SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS: &str = r#"{ - "version": "1.0.0", - "title": "Test Recipe", - "description": "A test recipe", - "prompt": "Test prompt", - "parameters": [ - { - "key": "key1", - "input_type": "string", - "requirement": "required", - "description": "A required string parameter" + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = None; + + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A test parameter" }, + "key2": { "type": "number", "description": "An optional parameter" } }, - { - "key": "key2", - "input_type": "number", - "requirement": "optional", - "description": "An optional parameter" - }, - { - "key": "key3", - "input_type": "date", - "requirement": "required", - "description": "A required date parameter" - } - ] - }"#; - - #[test] - fn test_with_one_param_in_tool_input() { - let (mut sub_recipe, _temp_dir) = - prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); - sub_recipe.values = - Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - sub_recipe.executions = Some(create_execution()); - - let result = get_input_schema(&sub_recipe).unwrap(); - - verify_task_parameters( - result, - json!({ - "type": "object", - "properties": { - "key3": { "type": "date", "description": "A required date parameter" } - }, - "required": ["key3"] - }), - ); - } - - #[test] - fn test_without_param_in_tool_input() { - let (mut sub_recipe, _temp_dir) = - prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); - sub_recipe.values = Some(HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key3".to_string(), "value3".to_string()), - ])); - sub_recipe.executions = Some(create_execution()); - - let result = get_input_schema(&sub_recipe).unwrap(); - - assert_eq!( - None, - result - .get("properties") - .unwrap() - .as_object() - .unwrap() - .get("task_parameters") - ); - } - - #[test] - fn test_with_all_params_in_tool_input() { - let (mut sub_recipe, _temp_dir) = - prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_THREE_PARAMS); - sub_recipe.values = None; - - let result = get_input_schema(&sub_recipe).unwrap(); - - verify_task_parameters( - result, - json!({ - "type": "object", - "properties": { - "key1": { "type": "string", "description": "A required string parameter" }, - "key2": { "type": "number", "description": "An optional parameter" }, - "key3": { "type": "date", "description": "A required date parameter" } - }, - "required": ["key1", "key3"] - }), - ); - } + "required": ["key1"] + }), + ); } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index 746e4ecb151c..a80b72593e23 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -4,6 +4,7 @@ use crate::agents::sub_recipe_execution_tool::executor::{ pub use crate::agents::sub_recipe_execution_tool::task_types::{ ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, }; +use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; #[cfg(test)] mod tests; @@ -16,10 +17,28 @@ pub async fn execute_tasks( input: Value, execution_mode: ExecutionMode, notifier: mpsc::Sender, + tasks_manager: &TasksManager, ) -> Result { - let tasks: Vec = - serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) - .map_err(|e| format!("Failed to parse tasks: {}", e))?; + let task_ids: Vec = serde_json::from_value( + input + .get("task_ids") + .ok_or("Missing task_ids field")? + .clone(), + ) + .map_err(|e| format!("Failed to parse task_ids: {}", e))?; + + let mut tasks = Vec::new(); + for task_id in &task_ids { + match tasks_manager.get_task(task_id).await { + Some(task) => tasks.push(task), + None => { + return Err(format!( + "Task with ID '{}' not found in TasksManager", + task_id + )) + } + } + } let task_count = tasks.len(); match execution_mode { diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index b6363ba20d14..0b7af3b5b644 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -5,5 +5,6 @@ pub mod sub_recipe_execute_task_tool; mod task_execution_tracker; mod task_types; mod tasks; +pub mod tasks_manager; pub mod utils; mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 40e57f2bd6a5..1133a2cbbac4 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -3,7 +3,8 @@ use serde_json::Value; use crate::agents::{ sub_recipe_execution_tool::lib::execute_tasks, - sub_recipe_execution_tool::task_types::ExecutionMode, tool_execution::ToolCallResult, + sub_recipe_execution_tool::task_types::ExecutionMode, + sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }; use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; @@ -47,55 +48,15 @@ Pre-created Task Based: "default": "sequential", "description": "Execution strategy for multiple tasks. For pre-created tasks, respect the execution_mode from task creation. For user intent, use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." }, - "tasks": { + "task_ids": { "type": "array", "items": { - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Unique identifier for the task" - }, - "task_type": { - "type": "string", - "enum": ["sub_recipe", "text_instruction"], - "default": "sub_recipe", - "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" - }, - "payload": { - "type": "object", - "properties": { - "sub_recipe": { - "type": "object", - "description": "sub recipe to execute", - "properties": { - "name": { - "type": "string", - "description": "name of the sub recipe to execute" - }, - "recipe_path": { - "type": "string", - "description": "path of the sub recipe file" - }, - "command_parameters": { - "type": "object", - "description": "parameters to pass to run recipe command with sub recipe file" - } - } - }, - "text_instruction": { - "type": "string", - "description": "text instruction to execute" - } - } - } - }, - "required": ["id", "payload"] - }, - "description": "The tasks to run in parallel" + "type": "string", + "description": "Unique identifier for the task" + } } }, - "required": ["tasks"] + "required": ["task_ids"] }), Some(ToolAnnotations { title: Some("Run tasks in parallel".to_string()), @@ -107,9 +68,10 @@ Pre-created Task Based: ) } -pub async fn run_tasks(execute_data: Value) -> ToolCallResult { +pub async fn run_tasks(execute_data: Value, tasks_manager: &TasksManager) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); + let tasks_manager_clone = tasks_manager.clone(); let result_future = async move { let execute_data_clone = execute_data.clone(); let execution_mode = execute_data_clone @@ -117,7 +79,14 @@ pub async fn run_tasks(execute_data: Value) -> ToolCallResult { .and_then(|v| serde_json::from_value::(v.clone()).ok()) .unwrap_or_default(); - match execute_tasks(execute_data, execution_mode, notification_tx).await { + match execute_tasks( + execute_data, + execution_mode, + notification_tx, + &tasks_manager_clone, + ) + .await + { Ok(result) => { let output = serde_json::to_string(&result).unwrap(); Ok(vec![Content::text(output)]) diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs new file mode 100644 index 000000000000..433478be1f29 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::agents::sub_recipe_execution_tool::task_types::Task; + +#[derive(Debug, Clone)] +pub struct TasksManager { + tasks: Arc>>, +} + +impl Default for TasksManager { + fn default() -> Self { + Self::new() + } +} + +impl TasksManager { + pub fn new() -> Self { + Self { + tasks: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn save_tasks(&self, tasks: Vec) { + let mut task_map = self.tasks.write().await; + for task in tasks { + task_map.insert(task.id.clone(), task); + } + } + + pub async fn get_task(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks.get(task_id).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn create_test_task(id: &str, sub_recipe_name: &str) -> Task { + Task { + id: id.to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": sub_recipe_name, + "command_parameters": {}, + "recipe_path": "/test/path" + } + }), + } + } + + #[tokio::test] + async fn test_save_and_get_task() { + let manager = TasksManager::new(); + let tasks = vec![create_test_task("task1", "weather")]; + + manager.save_tasks(tasks).await; + + let retrieved = manager.get_task("task1").await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().id, "task1"); + } + + #[tokio::test] + async fn test_save_multiple_tasks() { + let manager = TasksManager::new(); + let tasks = vec![ + create_test_task("task1", "weather"), + create_test_task("task2", "news"), + ]; + + manager.save_tasks(tasks).await; + + let task1 = manager.get_task("task1").await; + let task2 = manager.get_task("task2").await; + assert!(task1.is_some()); + assert!(task2.is_some()); + assert_eq!(task1.unwrap().id, "task1"); + assert_eq!(task2.unwrap().id, "task2"); + } +} diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index cb01c3ffe4dc..33229b97ecef 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -7,6 +7,7 @@ use crate::{ recipe_tools::sub_recipe_tools::{ create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, + sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }, recipe::SubRecipe, @@ -53,8 +54,11 @@ impl SubRecipeManager { &self, tool_name: &str, params: Value, + tasks_manager: &TasksManager, ) -> ToolCallResult { - let result = self.call_sub_recipe_tool(tool_name, params).await; + let result = self + .call_sub_recipe_tool(tool_name, params, tasks_manager) + .await; match result { Ok(call_result) => ToolCallResult::from(Ok(call_result)), Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), @@ -65,6 +69,7 @@ impl SubRecipeManager { &self, tool_name: &str, params: Value, + tasks_manager: &TasksManager, ) -> Result, ToolError> { let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { let sub_recipe_name = tool_name @@ -80,10 +85,10 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - let output = create_sub_recipe_task(sub_recipe, params) + let output = create_sub_recipe_task(sub_recipe, params, tasks_manager) .await .map_err(|e| { - ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) + ToolError::ExecutionError(format!("Sub-recipe task createion failed: {}", e)) })?; Ok(vec![Content::text(output)]) } diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 98ecbbbe9cd7..ca5f87766198 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -146,7 +146,6 @@ pub struct Execution { pub parallel: bool, } - fn deserialize_value_map_as_string<'de, D>( deserializer: D, ) -> Result>, D::Error> diff --git a/test_parallel_execution.yaml b/test_parallel_execution.yaml deleted file mode 100644 index 4a63b2ff5f2d..000000000000 --- a/test_parallel_execution.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "1.0.0" -title: "Test Parallel Execution" -description: "Test that executions.parallel=true is respected even with sequential prompt" - -parameters: - - key: city - input_type: string - requirement: required - description: "City name" - -sub_recipes: - - name: weather - path: "./weather_sub.yaml" - executions: - parallel: true - -prompt: | - Get weather for Sydney, Melbourne, and Brisbane in sequential order \ No newline at end of file diff --git a/weather_sub.yaml b/weather_sub.yaml deleted file mode 100644 index c39c3d377821..000000000000 --- a/weather_sub.yaml +++ /dev/null @@ -1,12 +0,0 @@ -version: "1.0.0" -title: "Weather Sub-Recipe" -description: "Get weather for a city" - -parameters: - - key: city - input_type: string - requirement: required - description: "City name" - -prompt: | - Get weather information for {{ city }} \ No newline at end of file From 3d8782bcc0b2bbd0661ae4cc7502dec0b8f8f9cc Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 16:46:17 +1000 Subject: [PATCH 32/34] change parallel to sequential_when_repeated --- crates/goose-cli/src/recipes/extract_from_cli.rs | 2 +- .../src/agents/recipe_tools/param_utils/tests.rs | 1 - .../src/agents/recipe_tools/sub_recipe_tools.rs | 12 +++--------- .../agents/recipe_tools/sub_recipe_tools/tests.rs | 1 - crates/goose/src/recipe/mod.rs | 4 ++-- 5 files changed, 6 insertions(+), 14 deletions(-) diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 034dc40de361..56113b75dacd 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -32,7 +32,7 @@ pub fn extract_recipe_info_from_cli( path: recipe_file_path.to_string_lossy().to_string(), name, values: None, - executions: None, + sequential_when_repeated: true, }; all_sub_recipes.push(additional_sub_recipe); } diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index 04687c106f5c..ee5e43a77447 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -10,7 +10,6 @@ fn setup_default_sub_recipe() -> SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), - executions: None, }; sub_recipe } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 809532c2e994..52db04366584 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -72,16 +72,10 @@ fn create_tasks_from_params( } fn get_execution_mode(sub_recipe: &SubRecipe) -> ExecutionMode { - let is_parallel = sub_recipe - .executions - .as_ref() - .map(|e| e.parallel) - .unwrap_or(false); - - if is_parallel { - ExecutionMode::Parallel - } else { + if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential + } else { + ExecutionMode::Parallel } } diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 69f347833ee0..5327c611e4d7 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -12,7 +12,6 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), - executions: None, }; sub_recipe } diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index ca5f87766198..135cdd322dd2 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -137,8 +137,8 @@ pub struct SubRecipe { pub path: String, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub executions: Option, + #[serde(default)] + pub sequential_when_repeated: bool, } #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Execution { From 990e9ab8330ed774db5907a2d5cde881e68d069c Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 20:18:04 +1000 Subject: [PATCH 33/34] make sure same sub recipe with different params to run sequentially when sequential_when_repeated is true --- .../agents/recipe_tools/param_utils/tests.rs | 95 +++++++++---------- .../agents/recipe_tools/sub_recipe_tools.rs | 18 +--- .../recipe_tools/sub_recipe_tools/tests.rs | 1 + .../sub_recipe_execution_tool/lib/mod.rs | 18 +++- .../sub_recipe_execute_task_tool.rs | 19 +--- .../task_execution_tracker.rs | 20 +++- .../sub_recipe_execution_tool/task_types.rs | 6 ++ 7 files changed, 96 insertions(+), 81 deletions(-) diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs index ee5e43a77447..583338d644a7 100644 --- a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -10,6 +10,7 @@ fn setup_default_sub_recipe() -> SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + sequential_when_repeated: true, }; sub_recipe } @@ -17,59 +18,55 @@ fn setup_default_sub_recipe() -> SubRecipe { mod prepare_command_params_tests { use super::*; - mod without_execution_runs { - use super::*; - - #[test] - fn test_return_command_param() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "value2".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); - - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } - - #[test] - fn test_return_command_param_when_value_override_passed_param_value() { - let parameter_array = vec![json!(HashMap::from([( - "key2".to_string(), - "different_value".to_string() - )]))]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = Some(HashMap::from([ + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()), - ])); + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!( - vec![HashMap::from([ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ]),], - result - ); - } + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } - #[test] - fn test_return_empty_command_param() { - let parameter_array = vec![]; - let mut sub_recipe = setup_default_sub_recipe(); - sub_recipe.values = None; + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; - let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); - assert_eq!(result, vec![HashMap::new()]); - } + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result, vec![HashMap::new()]); } mod multiple_tool_parameters { diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 52db04366584..810c4a60ae59 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -24,7 +24,7 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ Each task will run the same sub recipe but with different parameter values. \ This is useful when you need to execute the same sub recipe multiple times with varying inputs. \ - After creating the task list with execution_mode, pass them to the task executor to run all tasks.", + After creating the tasks and execution_mode is provided, pass them to the task executor to run these tasks", sub_recipe.name ), input_schema, @@ -58,6 +58,7 @@ fn create_tasks_from_params( "name": sub_recipe.name.clone(), "command_parameters": task_command_param, "recipe_path": sub_recipe.path.clone(), + "sequential_when_repeated": sub_recipe.sequential_when_repeated } }); Task { @@ -71,19 +72,11 @@ fn create_tasks_from_params( tasks } -fn get_execution_mode(sub_recipe: &SubRecipe) -> ExecutionMode { - if sub_recipe.sequential_when_repeated { - ExecutionMode::Sequential - } else { - ExecutionMode::Parallel - } -} - -fn create_task_execution_payload(tasks: &[Task], execution_mode: ExecutionMode) -> Value { +fn create_task_execution_payload(tasks: &[Task], sub_recipe: &SubRecipe) -> Value { let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); json!({ "task_ids": task_ids, - "execution_mode": execution_mode + "execution_mode": if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential } else { ExecutionMode::Parallel }, }) } @@ -95,8 +88,7 @@ pub async fn create_sub_recipe_task( let task_params_array = extract_task_parameters(¶ms); let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; let tasks = create_tasks_from_params(sub_recipe, &command_params); - let execution_mode = get_execution_mode(sub_recipe); - let task_execution_payload = create_task_execution_payload(&tasks, execution_mode); + let task_execution_payload = create_task_execution_payload(&tasks, sub_recipe); let tasks_json = serde_json::to_string(&task_execution_payload) .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 5327c611e4d7..0b682b0b649b 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -12,6 +12,7 @@ mod tests { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + sequential_when_repeated: true, }; sub_recipe } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs index a80b72593e23..446b6011c32f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -10,7 +10,7 @@ use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; mod tests; use mcp_core::protocol::JsonRpcMessage; -use serde_json::Value; +use serde_json::{json, Value}; use tokio::sync::mpsc; pub async fn execute_tasks( @@ -50,9 +50,21 @@ pub async fn execute_tasks( Err("Sequential execution mode requires exactly one task".to_string()) } } + ExecutionMode::Parallel => { - let response: ExecutionResponse = execute_tasks_in_parallel(tasks, notifier).await; - handle_response(response) + if tasks.iter().any(|task| task.get_sequential_when_repeated()) { + Ok(json!( + { + "execution_mode": ExecutionMode::Sequential, + "task_ids": task_ids, + "results": ["the tasks should be executed sequentially, no matter how user requests it. Please use the subrecipe__execute_task tool to execute the tasks sequentially."] + } + )) + } else { + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, notifier.clone()).await; + handle_response(response) + } } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 1133a2cbbac4..e5f9062f398d 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -16,29 +16,18 @@ pub fn create_sub_recipe_execute_task_tool() -> Tool { SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, "Only use this tool when you execute sub recipe task. EXECUTION STRATEGY DECISION: -1. PRE-CREATED TASKS: If tasks were created by subrecipe__create_task_* tools, check the execution_mode in the response: - - If execution_mode is 'parallel', use parallel execution - - If execution_mode is 'sequential', use sequential execution - - Always respect the execution_mode from task creation to maintain consistency - -2. USER INTENT: If creating tasks inline or user explicitly specifies: - - DEFAULT: Execute tasks sequentially unless user explicitly requests parallel execution - - PARALLEL: When user uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' +1. If the tasks are created with execution_mode, use the execution_mode. +2. Execute tasks sequentially unless user explicitly requests parallel execution. PARALLEL: User uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' IMPLEMENTATION: - Sequential execution: Call this tool multiple times, passing exactly ONE task per call - Parallel execution: Call this tool once, passing an ARRAY of all tasks EXAMPLES: -User Intent Based: - User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) - User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) - User: 'run these simultaneously' → Parallel (1 tool call with task array) -- User: 'do task A then task B' → Sequential (2 separate tool calls) - -Pre-created Task Based: -- subrecipe__create_task_weather returns execution_mode: 'parallel' → Use parallel execution -- subrecipe__create_task_weather returns execution_mode: 'sequential' → Use sequential execution", +- User: 'do task A then task B' → Sequential (2 separate tool calls)", serde_json::json!({ "type": "object", "properties": { @@ -46,7 +35,7 @@ Pre-created Task Based: "type": "string", "enum": ["sequential", "parallel"], "default": "sequential", - "description": "Execution strategy for multiple tasks. For pre-created tasks, respect the execution_mode from task creation. For user intent, use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." + "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." }, "task_ids": { "type": "array", diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs index b456fd77424f..a906a59a755f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -116,9 +116,27 @@ impl TaskExecutionTracker { pub async fn send_live_output(&self, task_id: &str, line: &str) { match self.display_mode { DisplayMode::SingleTaskOutput => { + let tasks = self.tasks.read().await; + let task_info = tasks.get(task_id); + + let formatted_line = if let Some(task_info) = task_info { + let task_name = get_task_name(task_info); + let task_type = task_info.task.task_type.clone(); + let metadata = format_task_metadata(task_info); + + if metadata.is_empty() { + format!("[{} ({})] {}", task_name, task_type, line) + } else { + format!("[{} ({}) {}] {}", task_name, task_type, metadata, line) + } + } else { + line.to_string() + }; + drop(tasks); + let event = TaskExecutionNotificationEvent::line_output( task_id.to_string(), - line.to_string(), + formatted_line, ); if let Err(e) = diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs index b1385f375cc7..4515bb8420af 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs @@ -28,6 +28,12 @@ impl Task { .flatten() } + pub fn get_sequential_when_repeated(&self) -> bool { + self.get_sub_recipe() + .and_then(|sr| sr.get("sequential_when_repeated").and_then(|v| v.as_bool())) + .unwrap_or_default() + } + pub fn get_command_parameters(&self) -> Option<&Map> { self.get_sub_recipe() .and_then(|sr| sr.get("command_parameters")) From fad41f89bee96403311e38411cecb11c38567f35 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 16 Jul 2025 22:01:32 +1000 Subject: [PATCH 34/34] apply character-boundary-safe slicing --- crates/goose-cli/src/session/task_execution_display/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/goose-cli/src/session/task_execution_display/mod.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs index 96d37b76d483..ec6c41ff201b 100644 --- a/crates/goose-cli/src/session/task_execution_display/mod.rs +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -51,7 +51,11 @@ fn process_output_for_display(output: &str) -> String { fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { if text.len() > max_len { - format!("{}...", &text[..max_len.saturating_sub(3)]) + let mut end = max_len.saturating_sub(3); + while end > 0 && !text.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &text[..end]) } else { text.to_string() }