Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 80 additions & 45 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ use super::platform_tools;
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
use crate::agents::subagent_task_config::TaskConfig;
use crate::agents::todo_tools::{
// todo_read_tool, todo_write_tool, // TODO: Re-enable after next release
TODO_READ_TOOL_NAME,
TODO_WRITE_TOOL_NAME,
todo_read_tool, todo_write_tool, TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME,
};
use crate::conversation::message::{Message, ToolRequest};

Expand Down Expand Up @@ -103,7 +101,6 @@ pub struct Agent {
pub(super) tool_route_manager: ToolRouteManager,
pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>,
pub(super) retry_manager: RetryManager,
pub(super) todo_list: Arc<Mutex<String>>,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -155,15 +152,6 @@ where
}

impl Agent {
const DEFAULT_TODO_MAX_CHARS: usize = 50_000;

fn get_todo_max_chars() -> usize {
std::env::var("GOOSE_TODO_MAX_CHARS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(Self::DEFAULT_TODO_MAX_CHARS)
}

pub fn new() -> Self {
// Create channels with buffer size 32 (adjust if needed)
let (confirm_tx, confirm_rx) = mpsc::channel(32);
Expand All @@ -189,7 +177,6 @@ impl Agent {
tool_route_manager: ToolRouteManager::new(),
scheduler_service: Mutex::new(None),
retry_manager,
todo_list: Arc::new(Mutex::new(String::new())),
}
}

Expand Down Expand Up @@ -292,14 +279,20 @@ impl Agent {
permission_check_result: &PermissionCheckResult,
message_tool_response: Arc<Mutex<Message>>,
cancel_token: Option<tokio_util::sync::CancellationToken>,
session: &Option<SessionConfig>,
) -> Result<Vec<(String, ToolStream)>> {
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();

// Handle pre-approved and read-only tools
for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() {
let (req_id, tool_result) = self
.dispatch_tool_call(tool_call, request.id.clone(), cancel_token.clone())
.dispatch_tool_call(
tool_call,
request.id.clone(),
cancel_token.clone(),
session,
)
.await;

tool_futures.push((
Expand Down Expand Up @@ -379,6 +372,7 @@ impl Agent {
tool_call: mcp_core::tool::ToolCall,
request_id: String,
cancellation_token: Option<CancellationToken>,
session: &Option<SessionConfig>,
) -> (String, Result<ToolCallResult, ErrorData>) {
// Check if this tool call should be allowed based on repetition monitoring
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
Expand Down Expand Up @@ -491,7 +485,21 @@ impl Agent {
)))
} else if tool_call.name == TODO_READ_TOOL_NAME {
// Handle task planner read tool
let todo_content = self.todo_list.lock().await.clone();
let session_file_path = if let Some(session_config) = session {
session::storage::get_path(session_config.id.clone()).ok()
} else {
None
};

let todo_content = if let Some(path) = session_file_path {
session::storage::read_metadata(&path)
.ok()
.and_then(|m| m.todo_content)
.unwrap_or_default()
} else {
String::new()
};

ToolCallResult::from(Ok(vec![Content::text(todo_content)]))
} else if tool_call.name == TODO_WRITE_TOOL_NAME {
// Handle task planner write tool
Expand All @@ -502,34 +510,66 @@ impl Agent {
.unwrap_or("")
.to_string();

// Acquire lock first to prevent race condition
let mut todo_list = self.todo_list.lock().await;

// Character limit validation
let char_count = content.chars().count();
let max_chars = Self::get_todo_max_chars();
let max_chars = std::env::var("GOOSE_TODO_MAX_CHARS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(50_000);

// Simple validation - reject if over limit (0 means unlimited)
if max_chars > 0 && char_count > max_chars {
return (
request_id,
Ok(ToolCallResult::from(Err(ErrorData::new(
ToolCallResult::from(Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
format!(
"Todo list too large: {} chars (max: {})",
char_count, max_chars
),
None,
)))
} else if let Some(session_config) = session {
// Update session metadata with new TODO content
match session::storage::get_path(session_config.id.clone()) {
Ok(path) => match session::storage::read_metadata(&path) {
Ok(mut metadata) => {
metadata.todo_content = Some(content);
let path_clone = path.clone();
let metadata_clone = metadata.clone();
let update_result = tokio::task::spawn(async move {
session::storage::update_metadata(&path_clone, &metadata_clone)
.await
})
.await;

match update_result {
Ok(Ok(_)) => ToolCallResult::from(Ok(vec![Content::text(
format!("Updated ({} chars)", char_count),
)])),
_ => ToolCallResult::from(Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"Failed to update session metadata".to_string(),
None,
))),
}
}
Err(_) => ToolCallResult::from(Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"Failed to read session metadata".to_string(),
None,
))),
},
Err(_) => ToolCallResult::from(Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
format!(
"Todo list too large: {} chars (max: {})",
char_count, max_chars
),
"Failed to get session path".to_string(),
None,
)))),
);
))),
}
} else {
ToolCallResult::from(Err(ErrorData::new(
ErrorCode::INTERNAL_ERROR,
"TODO tools require an active session to persist data".to_string(),
None,
)))
}

*todo_list = content;

ToolCallResult::from(Ok(vec![Content::text(format!(
"Updated ({} chars)",
char_count
))]))
} else if tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME {
match self
.tool_route_manager
Expand Down Expand Up @@ -756,8 +796,7 @@ impl Agent {
]);

// Add task planner tools
// TODO: Re-enable after next release
// prefixed_tools.extend([todo_read_tool(), todo_write_tool()]);
prefixed_tools.extend([todo_read_tool(), todo_write_tool()]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


// Dynamic task tool
prefixed_tools.push(create_dynamic_task_tool());
Expand Down Expand Up @@ -1090,7 +1129,8 @@ impl Agent {
let mut tool_futures = self.handle_approved_and_denied_tools(
&permission_check_result,
message_tool_response.clone(),
cancel_token.clone()
cancel_token.clone(),
&session
).await?;

let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
Expand Down Expand Up @@ -1508,7 +1548,6 @@ mod tests {
}

#[tokio::test]
#[ignore] // TODO: Re-enable after next release when TODO tools are re-enabled
async fn test_todo_tools_integration() -> Result<()> {
let agent = Agent::new();

Expand All @@ -1521,10 +1560,6 @@ mod tests {
assert!(todo_read.is_some(), "TODO read tool should be present");
assert!(todo_write.is_some(), "TODO write tool should be present");

// Test todo_list initialization
let todo_content = agent.todo_list.lock().await;
assert_eq!(*todo_content, "", "TODO list should be initially empty");

Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/goose/src/agents/todo_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub fn todo_write_tool() -> Tool {
}

#[cfg(test)]
mod tests {
mod unit_tests {
use super::*;

#[test]
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/tool_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl Agent {
while let Some((req_id, confirmation)) = rx.recv().await {
if req_id == request.id {
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone()).await;
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone(), &None).await;
let mut futures = tool_futures.lock().await;

futures.push((req_id, match tool_result {
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/context_mgmt/auto_compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ mod tests {
accumulated_total_tokens: Some(100),
accumulated_input_tokens: Some(50),
accumulated_output_tokens: Some(50),
todo_content: None,
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,7 @@ async fn run_scheduled_job_internal(
accumulated_total_tokens: None,
accumulated_input_tokens: None,
accumulated_output_tokens: None,
todo_content: None,
};
if let Err(e_fb) = crate::session::storage::save_messages_with_metadata(
&session_file_path,
Expand Down
7 changes: 6 additions & 1 deletion crates/goose/src/session/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ pub struct SessionMetadata {
pub accumulated_input_tokens: Option<i32>,
/// The number of output tokens used in the session. Accumulated across all messages.
pub accumulated_output_tokens: Option<i32>,
/// Session-scoped TODO list content
pub todo_content: Option<String>,
}

// Custom deserializer to handle old sessions without working_dir
// Custom deserializer to handle old sessions without working_dir and todo_content
impl<'de> Deserialize<'de> for SessionMetadata {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand All @@ -84,6 +86,7 @@ impl<'de> Deserialize<'de> for SessionMetadata {
accumulated_input_tokens: Option<i32>,
accumulated_output_tokens: Option<i32>,
working_dir: Option<PathBuf>,
todo_content: Option<String>, // For backward compatibility
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming this lets us load old sessions from before this feature?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

}

let helper = Helper::deserialize(deserializer)?;
Expand All @@ -105,6 +108,7 @@ impl<'de> Deserialize<'de> for SessionMetadata {
accumulated_input_tokens: helper.accumulated_input_tokens,
accumulated_output_tokens: helper.accumulated_output_tokens,
working_dir,
todo_content: helper.todo_content,
})
}
}
Expand All @@ -129,6 +133,7 @@ impl SessionMetadata {
accumulated_total_tokens: None,
accumulated_input_tokens: None,
accumulated_output_tokens: None,
todo_content: None,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/tests/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ mod final_output_tool_tests {
}),
);
let (_, result) = agent
.dispatch_tool_call(tool_call, "request_id".to_string(), None)
.dispatch_tool_call(tool_call, "request_id".to_string(), None, &None)
.await;

assert!(result.is_ok(), "Tool call should succeed");
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/tests/private_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ async fn test_schedule_tool_dispatch() {
};

let (request_id, result) = agent
.dispatch_tool_call(tool_call, "test_dispatch".to_string(), None)
.dispatch_tool_call(tool_call, "test_dispatch".to_string(), None, &None)
.await;
assert_eq!(request_id, "test_dispatch");
assert!(result.is_ok());
Expand Down
1 change: 1 addition & 0 deletions crates/goose/tests/test_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,5 +411,6 @@ pub fn create_test_session_metadata(message_count: usize, working_dir: &str) ->
accumulated_total_tokens: Some(100),
accumulated_input_tokens: Some(50),
accumulated_output_tokens: Some(50),
todo_content: None,
}
}
Loading
Loading