Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,35 @@ fn create_tasks_from_params(
sub_recipe: &SubRecipe,
command_params: &[std::collections::HashMap<String, String>],
) -> Vec<Task> {
// Extract task_timeout from sub_recipe values if present
let task_timeout = sub_recipe
.values
.as_ref()
.and_then(|values| values.get("task_timeout"))
.and_then(|timeout_str| timeout_str.parse::<u64>().ok());

let tasks: Vec<Task> = command_params
.iter()
.map(|task_command_param| {
let mut sub_recipe_data = json!({
"name": sub_recipe.name.clone(),
"command_parameters": task_command_param,
"recipe_path": sub_recipe.path.clone(),
"sequential_when_repeated": sub_recipe.sequential_when_repeated
});

// Add task_timeout to the payload if present
if let Some(timeout_seconds) = task_timeout {
sub_recipe_data.as_object_mut().unwrap().insert(
"task_timeout".to_string(),
json!(timeout_seconds)
);
}

let payload = json!({
"sub_recipe": {
"name": sub_recipe.name.clone(),
"command_parameters": task_command_param,
"recipe_path": sub_recipe.path.clone(),
"sequential_when_repeated": sub_recipe.sequential_when_repeated
}
"sub_recipe": sub_recipe_data
});

Task {
id: uuid::Uuid::new_v4().to_string(),
task_type: "sub_recipe".to_string(),
Expand Down
91 changes: 91 additions & 0 deletions crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,95 @@ mod tests {
);
}
}

mod create_tasks_from_params {
use super::*;
use crate::agents::recipe_tools::sub_recipe_tools::create_tasks_from_params;

#[test]
fn test_creates_tasks_without_timeout() {
let sub_recipe = setup_default_sub_recipe();
let command_params = vec![
HashMap::from([("param1".to_string(), "value1".to_string())]),
HashMap::from([("param2".to_string(), "value2".to_string())]),
];

let tasks = create_tasks_from_params(&sub_recipe, &command_params);

assert_eq!(tasks.len(), 2);
for (i, task) in tasks.iter().enumerate() {
assert_eq!(task.task_type, "sub_recipe");
let payload = task.payload.as_object().unwrap();
let sub_recipe_obj = payload.get("sub_recipe").unwrap().as_object().unwrap();
assert_eq!(sub_recipe_obj.get("name").unwrap(), "test_sub_recipe");
assert_eq!(sub_recipe_obj.get("recipe_path").unwrap(), "test_sub_recipe.yaml");
assert_eq!(sub_recipe_obj.get("sequential_when_repeated").unwrap(), true);
assert_eq!(sub_recipe_obj.get("command_parameters").unwrap(), &json!(command_params[i]));
assert!(sub_recipe_obj.get("task_timeout").is_none());
}
}

#[test]
fn test_creates_tasks_with_timeout() {
let mut sub_recipe = setup_default_sub_recipe();
sub_recipe.values = Some(HashMap::from([
("key1".to_string(), "value1".to_string()),
("task_timeout".to_string(), "3600".to_string()),
]));
let command_params = vec![
HashMap::from([("param1".to_string(), "value1".to_string())]),
];

let tasks = create_tasks_from_params(&sub_recipe, &command_params);

assert_eq!(tasks.len(), 1);
let task = &tasks[0];
assert_eq!(task.task_type, "sub_recipe");
let payload = task.payload.as_object().unwrap();
let sub_recipe_obj = payload.get("sub_recipe").unwrap().as_object().unwrap();
assert_eq!(sub_recipe_obj.get("task_timeout").unwrap(), 3600);
}

#[test]
fn test_ignores_invalid_timeout() {
let mut sub_recipe = setup_default_sub_recipe();
sub_recipe.values = Some(HashMap::from([
("key1".to_string(), "value1".to_string()),
("task_timeout".to_string(), "not_a_number".to_string()),
]));
let command_params = vec![
HashMap::from([("param1".to_string(), "value1".to_string())]),
];

let tasks = create_tasks_from_params(&sub_recipe, &command_params);

assert_eq!(tasks.len(), 1);
let task = &tasks[0];
let payload = task.payload.as_object().unwrap();
let sub_recipe_obj = payload.get("sub_recipe").unwrap().as_object().unwrap();
assert!(sub_recipe_obj.get("task_timeout").is_none());
}

#[test]
fn test_multiple_tasks_with_same_timeout() {
let mut sub_recipe = setup_default_sub_recipe();
sub_recipe.values = Some(HashMap::from([
("task_timeout".to_string(), "7200".to_string()),
]));
let command_params = vec![
HashMap::from([("param1".to_string(), "value1".to_string())]),
HashMap::from([("param2".to_string(), "value2".to_string())]),
HashMap::from([("param3".to_string(), "value3".to_string())]),
];

let tasks = create_tasks_from_params(&sub_recipe, &command_params);

assert_eq!(tasks.len(), 3);
for task in tasks.iter() {
let payload = task.payload.as_object().unwrap();
let sub_recipe_obj = payload.get("sub_recipe").unwrap().as_object().unwrap();
assert_eq!(sub_recipe_obj.get("task_timeout").unwrap(), 7200);
}
}
}
}
17 changes: 17 additions & 0 deletions crates/goose/src/agents/subagent_execution_tool/task_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ impl Task {
None
}
}

/// Extract task-specific timeout from the task payload
pub fn get_task_timeout(&self) -> Option<u64> {
// For sub_recipe tasks, get timeout from sub_recipe object
if self.task_type == "sub_recipe" {
self.get_sub_recipe()
.and_then(|sr| sr.get("task_timeout"))
.and_then(|timeout| timeout.as_u64())
} else {
// For text_instruction tasks, check if there's a sub_recipe field with timeout
self.payload
.get("sub_recipe")
.and_then(|sr| sr.as_object())
.and_then(|sr| sr.get("task_timeout"))
.and_then(|timeout| timeout.as_u64())
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
Loading