Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b60f02a
add working_dir to session metadata, update session route
salman1993 Mar 6, 2025
75e3173
goose changes in ui/desktop
salman1993 Mar 6, 2025
bfcc365
code compiles; added working dir in goose-cli
salman1993 Mar 6, 2025
70ba6ba
fmt, clippy
salman1993 Mar 6, 2025
14bf748
session working dir opens window, chat msgs dont show up
salman1993 Mar 6, 2025
480cb30
Robust approach to handle resuming sessions with URL params
salman1993 Mar 6, 2025
0b9536d
warn user about resuming session in different working dir in CLI
salman1993 Mar 6, 2025
f9251a6
warning in yellow color
salman1993 Mar 6, 2025
42e7437
show msg to chdir to working_dir in CLI by default
salman1993 Mar 6, 2025
bc7c7b7
fmt, minor fixes
salman1993 Mar 6, 2025
3a3b5b1
fix sessions storing working dir from CLI
salman1993 Mar 6, 2025
398db10
fix: resumedSession is undefined if we don't wait to load
salman1993 Mar 7, 2025
6743384
Merge branch 'main' into sm/sessions-workdir
salman1993 Mar 7, 2025
dc75143
remove log that keeps repeating
salman1993 Mar 7, 2025
256fe0b
raise error when trying to update session if file doesnt exist
salman1993 Mar 7, 2025
25ef711
temp stash changes
salman1993 Mar 7, 2025
6482aad
session working dir is working as expected
salman1993 Mar 7, 2025
be4935b
fmt, clippy
salman1993 Mar 7, 2025
344fcb3
remove update session metadata route
salman1993 Mar 7, 2025
b2e7d15
minor improvements
salman1993 Mar 7, 2025
4719051
simplify sessionId const in ChatView since it doesn't change
salman1993 Mar 7, 2025
d9873af
open new session in GOOSE_WORKING_DIR
salman1993 Mar 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions crates/goose-cli/src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,6 @@ pub async fn build_session(
let mut agent = AgentFactory::create(&AgentFactory::configured_version(), provider)
.expect("Failed to create agent");

// Setup extensions for the agent
for extension in ExtensionManager::get_all().expect("should load extensions") {
if extension.enabled {
let config = extension.config.clone();
agent
.add_extension(config.clone())
.await
.unwrap_or_else(|e| {
let err = match e {
ExtensionError::Transport(McpClientError::StdioProcessError(inner)) => {
inner
}
_ => e.to_string(),
};
println!("Failed to start extension: {}, {:?}", config.name(), err);
println!(
"Please check extension configuration for {}.",
config.name()
);
process::exit(1);
});
}
}

// Handle session file resolution and resuming
let session_file = if resume {
if let Some(identifier) = identifier {
Expand All @@ -69,6 +45,7 @@ pub async fn build_session(
));
process::exit(1);
}

session_file
} else {
// Try to resume most recent session
Expand All @@ -91,6 +68,48 @@ pub async fn build_session(
session::get_path(id)
};

if resume {
// Read the session metadata
let metadata = session::read_metadata(&session_file).unwrap_or_else(|e| {
output::render_error(&format!("Failed to read session metadata: {}", e));
process::exit(1);
});

// Ask user if they want to change the working directory
let change_workdir = cliclack::confirm(format!("{} The working directory of this session was set to {}. It does not match the current working directory. Would you like to change it?", style("WARNING:").yellow(), style(metadata.working_dir.display()).cyan()))
.initial_value(true)
.interact().expect("Failed to get user input");

if change_workdir {
std::env::set_current_dir(metadata.working_dir).unwrap();
}
}

// Setup extensions for the agent
// Extensions need to be added after the session is created because we change directory when resuming a session
for extension in ExtensionManager::get_all().expect("should load extensions") {
if extension.enabled {
let config = extension.config.clone();
agent
.add_extension(config.clone())
.await
.unwrap_or_else(|e| {
let err = match e {
ExtensionError::Transport(McpClientError::StdioProcessError(inner)) => {
inner
}
_ => e.to_string(),
};
println!("Failed to start extension: {}, {:?}", config.name(), err);
println!(
"Please check extension configuration for {}.",
config.name()
);
process::exit(1);
});
}
}

// Create new session
let mut session = Session::new(agent, session_file.clone());

Expand Down
15 changes: 12 additions & 3 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use completion::GooseCompleter;
use etcetera::choose_app_strategy;
use etcetera::AppStrategy;
use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::Agent;
use goose::agents::{Agent, SessionConfig};
use goose::config::Config;
use goose::message::{Message, MessageContent};
use goose::session;
Expand Down Expand Up @@ -197,7 +197,6 @@ impl Session {
/// Process a single message and get the response
async fn process_message(&mut self, message: String) -> Result<()> {
self.messages.push(Message::user().with_text(&message));

// Get the provider from the agent for description generation
let provider = self.agent.provider().await;

Expand Down Expand Up @@ -434,7 +433,17 @@ impl Session {

async fn process_agent_response(&mut self, interactive: bool) -> Result<()> {
let session_id = session::Identifier::Path(self.session_file.clone());
let mut stream = self.agent.reply(&self.messages, Some(session_id)).await?;
let mut stream = self
.agent
.reply(
&self.messages,
Some(SessionConfig {
id: session_id,
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
}),
)
.await?;

use futures::StreamExt;
loop {
Expand Down
7 changes: 7 additions & 0 deletions crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,13 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
style("logging to").dim(),
style(session_file.display()).dim().cyan(),
);
println!(
" {} {}",
style("working directory:").dim(),
style(std::env::current_dir().unwrap().display())
.cyan()
.dim()
);
}

pub fn display_greeting() {
Expand Down
23 changes: 19 additions & 4 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ use axum::{
};
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::message::{Message, MessageContent};
use goose::session;
use goose::{
agents::SessionConfig,
message::{Message, MessageContent},
};

use mcp_core::role::Role;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
convert::Infallible,
path::PathBuf,
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand All @@ -29,6 +33,7 @@ use tokio_stream::wrappers::ReceiverStream;
struct ChatRequest {
messages: Vec<Message>,
session_id: Option<String>,
session_working_dir: String,
}

// Custom SSE response type for streaming messages
Expand Down Expand Up @@ -108,8 +113,8 @@ async fn handler(
let (tx, rx) = mpsc::channel(100);
let stream = ReceiverStream::new(rx);

// Get messages directly from the request
let messages = request.messages;
let session_working_dir = request.session_working_dir;

// Generate a new session ID if not provided in the request
let session_id = request
Expand Down Expand Up @@ -149,7 +154,10 @@ async fn handler(
let mut stream = match agent
.reply(
&messages,
Some(session::Identifier::Name(session_id.clone())),
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(session_working_dir),
}),
)
.await
{
Expand Down Expand Up @@ -246,6 +254,7 @@ async fn handler(
struct AskRequest {
prompt: String,
session_id: Option<String>,
session_working_dir: String,
}

#[derive(Debug, Serialize)]
Expand All @@ -269,6 +278,8 @@ async fn ask_handler(
return Err(StatusCode::UNAUTHORIZED);
}

let session_working_dir = request.session_working_dir;

// Generate a new session ID if not provided in the request
let session_id = request
.session_id
Expand All @@ -289,7 +300,10 @@ async fn ask_handler(
let mut stream = match agent
.reply(
&messages,
Some(session::Identifier::Name(session_id.clone())),
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(session_working_dir),
}),
)
.await
{
Expand Down Expand Up @@ -464,6 +478,7 @@ mod tests {
serde_json::to_string(&AskRequest {
prompt: "test prompt".to_string(),
session_id: Some("test-session".to_string()),
session_working_dir: "test-working-dir".to_string(),
})
.unwrap(),
))
Expand Down
13 changes: 12 additions & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::HashMap;
use std::path::PathBuf;

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;

Expand All @@ -13,14 +15,23 @@ use crate::session;
use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;

/// Session configuration for an agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
/// Unique identifier for the session
pub id: session::Identifier,
/// Working directory for the session
pub working_dir: PathBuf,
}

/// Core trait defining the behavior of an Agent
#[async_trait]
pub trait Agent: Send + Sync {
/// Create a stream that yields each message as it's generated by the agent
async fn reply(
&self,
messages: &[Message],
session_id: Option<session::Identifier>,
session: Option<SessionConfig>,
) -> Result<BoxStream<'_, Result<Message>>>;

/// Add a new MCP client to the agent
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod reference;
mod summarize;
mod truncate;

pub use agent::Agent;
pub use agent::{Agent, SessionConfig};
pub use capabilities::Capabilities;
pub use extension::ExtensionConfig;
pub use factory::{register_agent, AgentFactory};
Expand Down
10 changes: 6 additions & 4 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

use super::agent::SessionConfig;
use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
Expand Down Expand Up @@ -70,11 +71,11 @@ impl Agent for ReferenceAgent {
// TODO implement
}

#[instrument(skip(self, messages), fields(user_message))]
#[instrument(skip(self, messages, session), fields(user_message))]
async fn reply(
&self,
messages: &[Message],
session_id: Option<session::Identifier>,
session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
Expand Down Expand Up @@ -148,10 +149,11 @@ impl Agent for ReferenceAgent {
capabilities.record_usage(usage.clone()).await;

// record usage for the session in the session file
if let Some(session_id) = session_id.clone() {
if let Some(session) = session.clone() {
// TODO: track session_id in langfuse tracing
let session_file = session::get_path(session_id);
let session_file = session::get_path(session.id);
let mut metadata = session::read_metadata(&session_file)?;
metadata.working_dir = session.working_dir;
metadata.total_tokens = usage.usage.total_tokens;
// The message count is the number of messages in the session + 1 for the response
// The message count does not include the tool response till next iteration
Expand Down
10 changes: 6 additions & 4 deletions crates/goose/src/agents/summarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};

use super::agent::SessionConfig;
use super::detect_read_only_tools;
use super::Agent;
use crate::agents::capabilities::Capabilities;
Expand Down Expand Up @@ -162,11 +163,11 @@ impl Agent for SummarizeAgent {
}
}

#[instrument(skip(self, messages), fields(user_message))]
#[instrument(skip(self, messages, session), fields(user_message))]
async fn reply(
&self,
messages: &[Message],
session_id: Option<session::Identifier>,
session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
Expand Down Expand Up @@ -246,10 +247,11 @@ impl Agent for SummarizeAgent {
capabilities.record_usage(usage.clone()).await;

// record usage for the session in the session file
if let Some(session_id) = session_id.clone() {
if let Some(session) = session.clone() {
// TODO: track session_id in langfuse tracing
let session_file = session::get_path(session_id);
let session_file = session::get_path(session.id);
let mut metadata = session::read_metadata(&session_file)?;
metadata.working_dir = session.working_dir;
metadata.total_tokens = usage.usage.total_tokens;
// The message count is the number of messages in the session + 1 for the response
// The message count does not include the tool response till next iteration
Expand Down
10 changes: 6 additions & 4 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};

use super::agent::SessionConfig;
use super::detect_read_only_tools;
use super::Agent;
use crate::agents::capabilities::Capabilities;
Expand Down Expand Up @@ -145,11 +146,11 @@ impl Agent for TruncateAgent {
}
}

#[instrument(skip(self, messages, session_id), fields(user_message))]
#[instrument(skip(self, messages, session), fields(user_message))]
async fn reply(
&self,
messages: &[Message],
session_id: Option<session::Identifier>,
session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
Expand Down Expand Up @@ -229,10 +230,11 @@ impl Agent for TruncateAgent {
capabilities.record_usage(usage.clone()).await;

// record usage for the session in the session file
if let Some(session_id) = session_id.clone() {
if let Some(session) = session.clone() {
// TODO: track session_id in langfuse tracing
let session_file = session::get_path(session_id);
let session_file = session::get_path(session.id);
let mut metadata = session::read_metadata(&session_file)?;
metadata.working_dir = session.working_dir;
metadata.total_tokens = usage.usage.total_tokens;
// The message count is the number of messages in the session + 1 for the response
// The message count does not include the tool response till next iteration
Expand Down
Loading
Loading