Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,18 @@ pub async fn handle_session_remove(
regex_string: Option<String>,
) -> Result<()> {
let session_manager = SessionManager::instance();
let all_sessions = match session_manager.list_sessions().await {
Ok(sessions) => sessions,
Err(e) => {
tracing::error!("Failed to retrieve sessions: {:?}", e);
return Err(anyhow::anyhow!("Failed to retrieve sessions"));
}
};

let matched_sessions: Vec<Session>;

if let Some(id_val) = session_id {
if let Some(session) = all_sessions.iter().find(|s| s.id == id_val) {
matched_sessions = vec![session.clone()];
} else {
return Err(anyhow::anyhow!("Session ID '{}' not found.", id_val));
match session_manager.get_session(&id_val, false).await {
Ok(session) => matched_sessions = vec![session],
Err(_) => return Err(anyhow::anyhow!("Session ID '{}' not found.", id_val)),
}
} else if let Some(name_val) = name {
if let Some(session) = all_sessions.iter().find(|s| s.name == name_val) {
matched_sessions = vec![session.clone()];
let all_sessions = session_manager.list_all_sessions().await?;
if let Some(session) = all_sessions.into_iter().find(|s| s.name == name_val) {
matched_sessions = vec![session];
} else {
return Err(anyhow::anyhow!(
"Session with name '{}' not found.",
Expand All @@ -118,7 +111,8 @@ pub async fn handle_session_remove(
let session_regex = Regex::new(&regex_val)
.with_context(|| format!("Invalid regex pattern '{}'", regex_val))?;

matched_sessions = all_sessions
let visible_sessions = session_manager.list_sessions().await?;
matched_sessions = visible_sessions
.into_iter()
.filter(|session| session_regex.is_match(&session.id))
.collect();
Expand All @@ -128,10 +122,11 @@ pub async fn handle_session_remove(
return Ok(());
}
} else {
if all_sessions.is_empty() {
let visible_sessions = session_manager.list_sessions().await?;
if visible_sessions.is_empty() {
return Err(anyhow::anyhow!("No sessions found."));
}
matched_sessions = prompt_interactive_session_removal(&all_sessions)?;
matched_sessions = prompt_interactive_session_removal(&visible_sessions)?;
}

if matched_sessions.is_empty() {
Expand Down
33 changes: 22 additions & 11 deletions crates/goose/src/session/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,11 @@ impl SessionManager {
}

pub async fn list_sessions_by_types(&self, types: &[SessionType]) -> Result<Vec<Session>> {
self.storage.list_sessions_by_types(types).await
self.storage.list_sessions_by_types(Some(types)).await
}

pub async fn list_all_sessions(&self) -> Result<Vec<Session>> {
self.storage.list_sessions_by_types(None).await
}

pub async fn delete_session(&self, id: &str) -> Result<()> {
Expand Down Expand Up @@ -1239,12 +1243,19 @@ impl SessionStorage {
Self::replace_conversation_inner(pool, session_id, conversation).await
}

async fn list_sessions_by_types(&self, types: &[SessionType]) -> Result<Vec<Session>> {
if types.is_empty() {
return Ok(Vec::new());
}
async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result<Vec<Session>> {
let (where_clause, binds): (String, Vec<String>) = match types {
Some(t) if !t.is_empty() => {
let placeholders: String = t.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
(
format!("WHERE s.session_type IN ({})", placeholders),
t.iter().map(|t| t.to_string()).collect(),
)
}
Some(_) => return Ok(Vec::new()),
None => (String::new(), Vec::new()),
};

let placeholders: String = types.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
let query = format!(
r#"
SELECT s.id, s.working_dir, s.name, s.description, s.user_set_name, s.session_type, s.created_at, s.updated_at, s.extension_data,
Expand All @@ -1255,24 +1266,24 @@ impl SessionStorage {
COUNT(m.id) as message_count
FROM sessions s
INNER JOIN messages m ON s.id = m.session_id
WHERE s.session_type IN ({})
{}
GROUP BY s.id
ORDER BY s.updated_at DESC
"#,
placeholders
where_clause
);

let mut q = sqlx::query_as::<_, Session>(&query);
for t in types {
q = q.bind(t.to_string());
for b in &binds {
q = q.bind(b);
}

let pool = self.pool().await?;
q.fetch_all(pool).await.map_err(Into::into)
}

async fn list_sessions(&self) -> Result<Vec<Session>> {
self.list_sessions_by_types(&[SessionType::User, SessionType::Scheduled])
self.list_sessions_by_types(Some(&[SessionType::User, SessionType::Scheduled]))
.await
}

Expand Down
Loading