diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 23dcc934d18d..07e20f748e9b 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -66,6 +66,8 @@ webbrowser = "1.0" indicatif = "0.17.11" urlencoding = "2" +ratatui = "0.30.0-alpha.4" +crossterm = "0.27" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 6b4ea2e9320c..309223816896 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -339,9 +339,13 @@ enum Command { value_name = "NAME", help = "Add builtin extensions by name (e.g., 'developer' or multiple: 'developer,github')", long_help = "Add one or more builtin extensions that are bundled with goose by specifying their names, comma-separated", - value_delimiter = ',' + value_delimiter = ',', )] builtins: Vec, + + /// Launch the interactive session using a full-screen TUI (powered by ratatui) + #[arg(long, help = "Use experimental TUI interface instead of line-by-line mode")] + tui: bool, }, /// Open the last project directory @@ -627,6 +631,7 @@ pub async fn cli() -> Result<()> { extensions, remote_extensions, builtins, + tui, }) => { return match command { Some(SessionCommand::List { @@ -659,7 +664,7 @@ pub async fn cli() -> Result<()> { Ok(()) } None => { - // Run session command by default + // Run session command by default (either standard REPL or TUI) let mut session: crate::Session = build_session(SessionBuilderConfig { identifier: identifier.map(extract_identifier), resume, @@ -687,7 +692,11 @@ pub async fn cli() -> Result<()> { session.render_message_history(); } - let _ = session.interactive(None).await; + if tui { + let _ = session.interactive_tui(None).await; + } else { + let _ = session.interactive(None).await; + } Ok(()) } }; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 18916550629a..6e2930b85b45 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -5,6 +5,7 @@ mod input; mod output; mod prompt; mod thinking; +pub mod tui; pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; @@ -1300,6 +1301,18 @@ impl Session { Ok(path) } + + /// Start an interactive TUI session using `ratatui`. + pub async fn interactive_tui(&mut self, message: Option) -> Result<()> { + // Process initial message if provided so users see the conversation when the UI starts. + if let Some(msg) = message { + self.process_message(msg).await?; + } + // Build and run the TUI. The TUI owns a _mutable reference_ to `self` while running, so we + // need to construct it first and then await its future. + let tui = tui::GooseTui::new(self)?; + tui.run().await + } } fn get_reasoner() -> Result, anyhow::Error> { diff --git a/crates/goose-cli/src/session/tui.rs b/crates/goose-cli/src/session/tui.rs new file mode 100644 index 000000000000..1e514e00b37d --- /dev/null +++ b/crates/goose-cli/src/session/tui.rs @@ -0,0 +1,165 @@ +use anyhow::Result; +use crossterm::{ + event::{self, Event, KeyCode}, + execute, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, +}; +use ratatui::{ + backend::CrosstermBackend, + layout::{Constraint, Direction, Layout}, + style::{Color, Style}, + text::{Line, Span}, + widgets::{Block, Borders, Paragraph, Wrap}, + Terminal, +}; +use std::{ + io::{self, Stdout}, + time::Duration, +}; +use mcp_core::role::Role; + +/// Very small abstraction layer so we don't have to expose the whole `ratatui` types to the parent +/// modules. The struct just keeps the terminal alive while the TUI runs. +pub struct GooseTui<'a> { + terminal: Terminal>, + /// Buffer holding the user input while they type + input: String, + /// Shared reference to an interactive [`crate::Session`]. Held as mutable reference so we can + /// push messages and request completions. + session: &'a mut crate::Session, + /// Scroll offset for the chat history panel + scroll: u16, + /// Stores the rendered text for each historical message. We keep things as simple `String`s for + /// now – every line break yields a new line on screen which is good enough for a first cut. + history: Vec<(String, bool /* is_user */)>, +} + +impl<'a> GooseTui<'a> { + pub fn new(session: &'a mut crate::Session) -> Result { + enable_raw_mode()?; + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen)?; + let backend = CrosstermBackend::new(stdout); + let terminal = Terminal::new(backend)?; + Ok(Self { + terminal, + input: String::new(), + session, + scroll: 0, + history: Vec::new(), + }) + } + + /// Consumes the TUI, restoring the terminal. + fn teardown(&mut self) -> Result<()> { + disable_raw_mode()?; + execute!(self.terminal.backend_mut(), LeaveAlternateScreen)?; + self.terminal.show_cursor()?; + Ok(()) + } + + /// Run the TUI main loop. This will block until the user presses . + pub async fn run(mut self) -> Result<()> { + loop { + // Draw UI + self.terminal.draw(|f| { + let size = f.size(); + + // Split screen into message area + input line (3 rows) + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Min(1), Constraint::Length(3)].as_ref()) + .split(size); + + // Render message history + let history_lines: Vec = self + .history + .iter() + .flat_map(|(line, is_user)| { + let clr = if *is_user { Color::Yellow } else { Color::White }; + line.split('\n') + .map(move |l| { + Line::from(vec![Span::styled( + l.to_owned(), + Style::default().fg(clr), + )]) + }) + .collect::>() + }) + .collect(); + + let history_para = Paragraph::new(history_lines) + .block(Block::default().title("Messages").borders(Borders::ALL)) + .wrap(Wrap { trim: false }); + f.render_widget(history_para, chunks[0]); + + // Render input area + let input_para = Paragraph::new(self.input.as_str()) + .style(Style::default().fg(Color::Cyan)) + .block(Block::default().title("Input (Esc to quit)").borders(Borders::ALL)); + f.render_widget(input_para, chunks[1]); + // Put cursor at end of input buffer + let x = chunks[1].x + (self.input.len() as u16) + 1; + let y = chunks[1].y + 1; + #[allow(deprecated)] + { + f.set_cursor(x, y); + } + })?; + + // Handle events + if event::poll(Duration::from_millis(100))? { + if let Event::Key(key) = event::read()? { + match key.code { + KeyCode::Char(c) => { + self.input.push(c); + } + KeyCode::Backspace => { + self.input.pop(); + } + KeyCode::Enter => { + let user_msg = self.input.trim().to_string(); + if !user_msg.is_empty() { + // Push to local history first so the user gets immediate feedback + self.history.push((format!("You: {}", &user_msg), true)); + + // Clear input buffer before awaiting async call so the UI remains responsive + self.input.clear(); + + // Run the agent interaction synchronously for now (will freeze UI briefly). + if let Err(e) = self.session.process_message(user_msg).await { + self.history.push((format!("Error: {}", e), false)); + } + + // After processing (successful or not), refresh from session's message history. + let new_msgs = self.session.message_history(); + self.history = new_msgs + .iter() + .flat_map(|m| { + let mut lines = Vec::new(); + let sender = match m.role { + Role::User => "You", + Role::Assistant => "Assistant", + }; + let text_concat = m.as_concat_text(); + for l in text_concat.split('\n') { + let is_user = matches!(m.role, Role::User); + lines.push((format!("{}: {}", sender, l), is_user)); + } + lines + }) + .collect(); + } + } + KeyCode::Esc => { + break; + } + _ => {} + } + } + } + } + + self.teardown() + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 293360e234c8..73e2f703aba5 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -12,6 +12,8 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; +use crate::config::Config; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolRecord { pub tool_name: String, @@ -53,7 +55,25 @@ impl ToolVectorDB { Ok(tool_db) } - fn get_db_path() -> Result { + pub fn get_db_path() -> Result { + let config = Config::global(); + + // Check for custom database path override + if let Ok(custom_path) = config.get_param::("GOOSE_VECTOR_DB_PATH") { + let path = PathBuf::from(custom_path); + + // Validate the path is absolute + if !path.is_absolute() { + return Err(anyhow::anyhow!( + "GOOSE_VECTOR_DB_PATH must be an absolute path, got: {}", + path.display() + )); + } + + return Ok(path); + } + + // Fall back to default XDG-based path let data_dir = Xdg::new() .context("Failed to determine base strategy")? .data_dir(); @@ -363,6 +383,7 @@ mod tests { use super::*; #[tokio::test] + #[serial_test::serial] async fn test_tool_vectordb_creation() { let db = ToolVectorDB::new(Some("test_tools_vectordb_creation".to_string())) .await @@ -372,6 +393,7 @@ mod tests { } #[tokio::test] + #[serial_test::serial] async fn test_tool_vectordb_operations() -> Result<()> { // Create a new database instance with a unique table name let db = ToolVectorDB::new(Some("test_tool_vectordb_operations".to_string())).await?; @@ -440,6 +462,7 @@ mod tests { } #[tokio::test] + #[serial_test::serial] async fn test_empty_db() -> Result<()> { // Create a new database instance with a unique table name let db = ToolVectorDB::new(Some("test_empty_db".to_string())).await?; @@ -458,6 +481,7 @@ mod tests { } #[tokio::test] + #[serial_test::serial] async fn test_tool_deletion() -> Result<()> { // Create a new database instance with a unique table name let db = ToolVectorDB::new(Some("test_tool_deletion".to_string())).await?; @@ -490,4 +514,74 @@ mod tests { Ok(()) } + + #[test] + #[serial_test::serial] + fn test_custom_db_path_override() -> Result<()> { + use std::env; + use tempfile::TempDir; + + // Create a temporary directory for testing + let temp_dir = TempDir::new().unwrap(); + let custom_path = temp_dir.path().join("custom_vector_db"); + + // Set the environment variable + env::set_var("GOOSE_VECTOR_DB_PATH", custom_path.to_str().unwrap()); + + // Test that get_db_path returns the custom path + let db_path = ToolVectorDB::get_db_path()?; + assert_eq!(db_path, custom_path); + + // Clean up + env::remove_var("GOOSE_VECTOR_DB_PATH"); + + Ok(()) + } + + #[test] + #[serial_test::serial] + fn test_custom_db_path_validation() { + use std::env; + + // Test that relative paths are rejected + env::set_var("GOOSE_VECTOR_DB_PATH", "relative/path"); + + let result = ToolVectorDB::get_db_path(); + assert!( + result.is_err(), + "Expected error for relative path, got: {:?}", + result + ); + assert!(result + .unwrap_err() + .to_string() + .contains("must be an absolute path")); + + // Clean up + env::remove_var("GOOSE_VECTOR_DB_PATH"); + } + + #[test] + #[serial_test::serial] + fn test_fallback_to_default_path() -> Result<()> { + use std::env; + + // Ensure no custom path is set + env::remove_var("GOOSE_VECTOR_DB_PATH"); + + // Test that it falls back to default XDG path + let db_path = ToolVectorDB::get_db_path()?; + assert!( + db_path.to_string_lossy().contains("goose"), + "Path should contain 'goose', got: {}", + db_path.display() + ); + assert!( + db_path.to_string_lossy().contains("tool_db"), + "Path should contain 'tool_db', got: {}", + db_path.display() + ); + + Ok(()) + } } diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 149499f5bde1..3716df0e6dc3 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -151,4 +151,59 @@ impl Provider for GroqProvider { super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) } + + /// Fetch supported models from Groq; returns Err on failure, Ok(None) if no models found + async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { + // Construct the Groq models endpoint + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {}", e)))?; + let url = base_url.join("openai/v1/models").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {}", e)) + })?; + + // Build the request with required headers + let request = self + .client + .get(url) + .bearer_auth(&self.api_key) + .header("Content-Type", "application/json"); + + // Send request + let response = request.send().await?; + let status = response.status(); + let payload: serde_json::Value = response.json().await.map_err(|_| { + ProviderError::RequestFailed("Response body is not valid JSON".to_string()) + })?; + + // Check for error response from API + if let Some(err_obj) = payload.get("error") { + let msg = err_obj + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + return Err(ProviderError::Authentication(msg.to_string())); + } + + // Extract model names + if status == StatusCode::OK { + let data = payload + .get("data") + .and_then(|v| v.as_array()) + .ok_or_else(|| { + ProviderError::UsageError("Missing or invalid `data` field in response".into()) + })?; + + let mut model_names: Vec = data + .iter() + .filter_map(|m| m.get("id").and_then(Value::as_str).map(String::from)) + .collect(); + model_names.sort(); + Ok(Some(model_names)) + } else { + Err(ProviderError::RequestFailed(format!( + "Groq API returned error status: {}. Payload: {:?}", + status, payload + ))) + } + } } diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index 786b09b4d2e3..77ac9b94bbbc 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -1,10 +1,18 @@ +// IMPORTANT: This file includes session recovery functionality to handle corrupted session files. +// Only essential logging is included with the [SESSION] prefix to track: +// - Total message counts +// - Corruption detection and recovery +// - Backup creation +// Additional debug logging can be added if needed for troubleshooting. + use crate::message::Message; use crate::providers::base::Provider; use anyhow::Result; use chrono::Local; use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; +use regex::Regex; use serde::{Deserialize, Serialize}; -use std::fs::{self, File}; +use std::fs; use std::io::{self, BufRead, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -207,24 +215,56 @@ pub fn generate_session_id() -> String { Local::now().format("%Y%m%d_%H%M%S").to_string() } -/// Read messages from a session file +/// Read messages from a session file with corruption recovery /// /// Creates the file if it doesn't exist, reads and deserializes all messages if it does. /// The first line of the file is expected to be metadata, and the rest are messages. /// Large messages are automatically truncated to prevent memory issues. +/// Includes recovery mechanisms for corrupted files. pub fn read_messages(session_file: &Path) -> Result> { - read_messages_with_truncation(session_file, Some(50000)) // 50KB limit per message content + let result = read_messages_with_truncation(session_file, Some(50000)); // 50KB limit per message content + match &result { + Ok(messages) => println!( + "[SESSION] Successfully read {} messages from: {:?}", + messages.len(), + session_file + ), + Err(e) => println!( + "[SESSION] Failed to read messages from {:?}: {}", + session_file, e + ), + } + result } -/// Read messages from a session file with optional content truncation +/// Read messages from a session file with optional content truncation and corruption recovery /// /// Creates the file if it doesn't exist, reads and deserializes all messages if it does. /// The first line of the file is expected to be metadata, and the rest are messages. /// If max_content_size is Some, large message content will be truncated during loading. +/// Includes robust error handling and corruption recovery mechanisms. pub fn read_messages_with_truncation( session_file: &Path, max_content_size: Option, ) -> Result> { + // Check if there's a backup file we should restore from + let backup_file = session_file.with_extension("backup"); + if !session_file.exists() && backup_file.exists() { + println!( + "[SESSION] Session file missing but backup exists, restoring from backup: {:?}", + backup_file + ); + tracing::warn!( + "[SESSION] Session file missing but backup exists, restoring from backup: {:?}", + backup_file + ); + if let Err(e) = fs::copy(&backup_file, session_file) { + println!("[SESSION] Failed to restore from backup: {}", e); + tracing::error!("Failed to restore from backup: {}", e); + } + } + + // Open the file with appropriate options let file = fs::OpenOptions::new() .read(true) .write(true) @@ -235,27 +275,137 @@ pub fn read_messages_with_truncation( let reader = io::BufReader::new(file); let mut lines = reader.lines(); let mut messages = Vec::new(); + let mut corrupted_lines = Vec::new(); + let mut line_number = 1; // Read the first line as metadata or create default if empty/missing - if let Some(line) = lines.next() { - let line = line?; - // Try to parse as metadata, but if it fails, treat it as a message - if let Ok(_metadata) = serde_json::from_str::(&line) { - // Metadata successfully parsed, continue with the rest of the lines as messages - } else { - // This is not metadata, it's a message - let message = parse_message_with_truncation(&line, max_content_size)?; - messages.push(message); + if let Some(line_result) = lines.next() { + match line_result { + Ok(line) => { + // Try to parse as metadata, but if it fails, treat it as a message + if let Ok(_metadata) = serde_json::from_str::(&line) { + // Metadata successfully parsed, continue with the rest of the lines as messages + } else { + // This is not metadata, it's a message + match parse_message_with_truncation(&line, max_content_size) { + Ok(message) => { + messages.push(message); + } + Err(e) => { + println!("[SESSION] Failed to parse first line as message: {}", e); + println!("[SESSION] Attempting to recover corrupted first line..."); + tracing::warn!("Failed to parse first line as message: {}", e); + + // Try to recover the corrupted line + match attempt_corruption_recovery(&line, max_content_size) { + Ok(recovered) => { + println!( + "[SESSION] Successfully recovered corrupted first line!" + ); + messages.push(recovered); + } + Err(recovery_err) => { + println!( + "[SESSION] Failed to recover corrupted first line: {}", + recovery_err + ); + corrupted_lines.push((line_number, line)); + } + } + } + } + } + } + Err(e) => { + println!("[SESSION] Failed to read first line: {}", e); + tracing::error!("Failed to read first line: {}", e); + corrupted_lines.push((line_number, "[Unreadable line]".to_string())); + } } + line_number += 1; } // Read the rest of the lines as messages - for line in lines { - let line = line?; - let message = parse_message_with_truncation(&line, max_content_size)?; - messages.push(message); + for line_result in lines { + match line_result { + Ok(line) => match parse_message_with_truncation(&line, max_content_size) { + Ok(message) => { + messages.push(message); + } + Err(e) => { + println!("[SESSION] Failed to parse line {}: {}", line_number, e); + println!( + "[SESSION] Attempting to recover corrupted line {}...", + line_number + ); + tracing::warn!("Failed to parse line {}: {}", line_number, e); + + // Try to recover the corrupted line + match attempt_corruption_recovery(&line, max_content_size) { + Ok(recovered) => { + println!( + "[SESSION] Successfully recovered corrupted line {}!", + line_number + ); + messages.push(recovered); + } + Err(recovery_err) => { + println!( + "[SESSION] Failed to recover corrupted line {}: {}", + line_number, recovery_err + ); + corrupted_lines.push((line_number, line)); + } + } + } + }, + Err(e) => { + println!("[SESSION] Failed to read line {}: {}", line_number, e); + tracing::error!("Failed to read line {}: {}", line_number, e); + corrupted_lines.push((line_number, "[Unreadable line]".to_string())); + } + } + line_number += 1; + } + + // If we found corrupted lines, create a backup and log the issues + if !corrupted_lines.is_empty() { + println!( + "[SESSION] Found {} corrupted lines, creating backup", + corrupted_lines.len() + ); + tracing::warn!( + "[SESSION] Found {} corrupted lines in session file, creating backup", + corrupted_lines.len() + ); + + // Create a backup of the original file + if !backup_file.exists() { + if let Err(e) = fs::copy(session_file, &backup_file) { + println!("[SESSION] Failed to create backup file: {}", e); + tracing::error!("Failed to create backup file: {}", e); + } else { + println!("[SESSION] Created backup file: {:?}", backup_file); + tracing::info!("Created backup file: {:?}", backup_file); + } + } + + // Log details about corrupted lines + for (num, line) in &corrupted_lines { + let preview = if line.len() > 50 { + format!("{}... (truncated)", &line[..50]) + } else { + line.clone() + }; + tracing::debug!("Corrupted line {}: {}", num, preview); + } } + println!( + "[SESSION] Finished reading session file. Total messages: {}, corrupted lines: {}", + messages.len(), + corrupted_lines.len() + ); Ok(messages) } @@ -273,9 +423,13 @@ fn parse_message_with_truncation( } Ok(message) } - Err(e) => { + Err(_e) => { // If parsing fails and the string is very long, it might be due to size if json_str.len() > 100000 { + println!( + "[SESSION] Very large message detected ({}KB), attempting truncation", + json_str.len() / 1024 + ); tracing::warn!( "Failed to parse very large message ({}KB), attempting truncation", json_str.len() / 1024 @@ -290,18 +444,21 @@ fn parse_message_with_truncation( match serde_json::from_str::(&truncated_json) { Ok(message) => { + println!("[SESSION] Successfully parsed message after truncation"); tracing::info!("Successfully parsed message after JSON truncation"); Ok(message) } Err(_) => { - tracing::error!("Failed to parse message even after truncation, skipping"); - // Return a placeholder message indicating the issue - Ok(Message::user() - .with_text("[Message too large to load - content truncated]")) + println!( + "[SESSION] Failed to parse even after truncation, attempting recovery" + ); + tracing::error!("Failed to parse message even after truncation"); + attempt_corruption_recovery(json_str, max_content_size) } } } else { - Err(e.into()) + // Try intelligent corruption recovery + attempt_corruption_recovery(json_str, max_content_size) } } } @@ -365,6 +522,235 @@ fn truncate_message_content_in_place(message: &mut Message, max_content_size: us } } +/// Attempt to recover corrupted JSON lines using various strategies +fn attempt_corruption_recovery(json_str: &str, max_content_size: Option) -> Result { + // Strategy 1: Try to fix common JSON corruption issues + if let Ok(message) = try_fix_json_corruption(json_str, max_content_size) { + println!("[SESSION] Recovered using JSON corruption fix"); + return Ok(message); + } + + // Strategy 2: Try to extract partial content if it looks like a message + if let Ok(message) = try_extract_partial_message(json_str) { + println!("[SESSION] Recovered using partial message extraction"); + return Ok(message); + } + + // Strategy 3: Try to fix truncated JSON + if let Ok(message) = try_fix_truncated_json(json_str, max_content_size) { + println!("[SESSION] Recovered using truncated JSON fix"); + return Ok(message); + } + + // Strategy 4: Create a placeholder message with the raw content + println!("[SESSION] All recovery strategies failed, creating placeholder message"); + let preview = if json_str.len() > 200 { + format!("{}...", &json_str[..200]) + } else { + json_str.to_string() + }; + + Ok(Message::user().with_text(format!( + "[RECOVERED FROM CORRUPTED LINE]\nOriginal content preview: {}\n\n[This message was recovered from a corrupted session file line. The original data may be incomplete.]", + preview + ))) +} + +/// Try to fix common JSON corruption patterns +fn try_fix_json_corruption(json_str: &str, max_content_size: Option) -> Result { + let mut fixed_json = json_str.to_string(); + let mut fixes_applied = Vec::new(); + + // Fix 1: Remove trailing commas before closing braces/brackets + if fixed_json.contains(",}") || fixed_json.contains(",]") { + fixed_json = fixed_json.replace(",}", "}").replace(",]", "]"); + fixes_applied.push("trailing commas"); + } + + // Fix 2: Try to close unclosed quotes in text fields + if let Some(text_start) = fixed_json.find("\"text\":\"") { + let content_start = text_start + 8; + if let Some(remaining) = fixed_json.get(content_start..) { + // Count quotes to see if we have an odd number (unclosed quote) + let quote_count = remaining.matches('"').count(); + if quote_count % 2 == 1 { + // Find the last quote and see if we need to close it + if let Some(last_quote_pos) = remaining.rfind('"') { + let after_last_quote = &remaining[last_quote_pos + 1..]; + if !after_last_quote.trim_start().starts_with(',') + && !after_last_quote.trim_start().starts_with('}') + { + // Insert a closing quote before the next field or end + if let Some(next_field) = after_last_quote.find(',') { + fixed_json.insert(content_start + last_quote_pos + 1 + next_field, '"'); + fixes_applied.push("unclosed quotes"); + } else if after_last_quote.contains('}') { + if let Some(brace_pos) = after_last_quote.find('}') { + fixed_json + .insert(content_start + last_quote_pos + 1 + brace_pos, '"'); + fixes_applied.push("unclosed quotes"); + } + } + } + } + } + } + } + + // Fix 3: Try to close unclosed JSON objects/arrays + let open_braces = fixed_json.matches('{').count(); + let close_braces = fixed_json.matches('}').count(); + let open_brackets = fixed_json.matches('[').count(); + let close_brackets = fixed_json.matches(']').count(); + + if open_braces > close_braces { + for _ in 0..(open_braces - close_braces) { + fixed_json.push('}'); + } + fixes_applied.push("unclosed braces"); + } + + if open_brackets > close_brackets { + for _ in 0..(open_brackets - close_brackets) { + fixed_json.push(']'); + } + fixes_applied.push("unclosed brackets"); + } + + // Fix 4: Remove control characters that might break JSON parsing + let original_len = fixed_json.len(); + fixed_json = fixed_json + .chars() + .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t') + .collect(); + if fixed_json.len() != original_len { + fixes_applied.push("control characters"); + } + + if !fixes_applied.is_empty() { + println!("[SESSION] Applied JSON fixes: {}", fixes_applied.join(", ")); + + match serde_json::from_str::(&fixed_json) { + Ok(mut message) => { + if let Some(max_size) = max_content_size { + truncate_message_content_in_place(&mut message, max_size); + } + return Ok(message); + } + Err(e) => { + println!("[SESSION] JSON fixes didn't work: {}", e); + } + } + } + + Err(anyhow::anyhow!("JSON corruption fixes failed")) +} + +/// Try to extract a partial message from corrupted JSON +fn try_extract_partial_message(json_str: &str) -> Result { + // Look for recognizable patterns that indicate this was a message + + // Try to extract role + let role = if json_str.contains("\"role\":\"user\"") { + mcp_core::role::Role::User + } else if json_str.contains("\"role\":\"assistant\"") { + mcp_core::role::Role::Assistant + } else { + mcp_core::role::Role::User // Default fallback + }; + + // Try to extract text content + let mut extracted_text = String::new(); + + // Look for text field content + if let Some(text_start) = json_str.find("\"text\":\"") { + let content_start = text_start + 8; + if let Some(content_end) = json_str[content_start..].find("\",") { + extracted_text = json_str[content_start..content_start + content_end].to_string(); + } else if let Some(content_end) = json_str[content_start..].find("\"") { + extracted_text = json_str[content_start..content_start + content_end].to_string(); + } else { + // Take everything after "text":" until we hit a likely end + let remaining = &json_str[content_start..]; + if let Some(end_pos) = remaining.find('}') { + extracted_text = remaining[..end_pos].trim_end_matches('"').to_string(); + } else { + extracted_text = remaining.to_string(); + } + } + } + + // If we couldn't extract text, try to find any readable content + if extracted_text.is_empty() { + // Look for any quoted strings that might be content + let quote_pattern = Regex::new(r#""([^"]{10,})""#).unwrap(); + if let Some(captures) = quote_pattern.find(json_str) { + extracted_text = captures.as_str().trim_matches('"').to_string(); + } + } + + if !extracted_text.is_empty() { + println!( + "[SESSION] Extracted text content: {}", + if extracted_text.len() > 50 { + &extracted_text[..50] + } else { + &extracted_text + } + ); + + let message = match role { + mcp_core::role::Role::User => Message::user(), + mcp_core::role::Role::Assistant => Message::assistant(), + }; + + return Ok(message.with_text(format!("[PARTIALLY RECOVERED] {}", extracted_text))); + } + + Err(anyhow::anyhow!("Could not extract partial message")) +} + +/// Try to fix truncated JSON by completing it +fn try_fix_truncated_json(json_str: &str, max_content_size: Option) -> Result { + let mut completed_json = json_str.to_string(); + + // If the JSON appears to be cut off mid-field, try to complete it + if !completed_json.trim().ends_with('}') && !completed_json.trim().ends_with(']') { + // Try to find where it was likely cut off + if let Some(last_quote) = completed_json.rfind('"') { + let after_quote = &completed_json[last_quote + 1..]; + if !after_quote.contains('"') && !after_quote.contains('}') { + // Looks like it was cut off in the middle of a string value + completed_json.push('"'); + + // Try to close the JSON structure + let open_braces = completed_json.matches('{').count(); + let close_braces = completed_json.matches('}').count(); + + for _ in 0..(open_braces - close_braces) { + completed_json.push('}'); + } + + println!("[SESSION] Attempting to complete truncated JSON"); + + match serde_json::from_str::(&completed_json) { + Ok(mut message) => { + if let Some(max_size) = max_content_size { + truncate_message_content_in_place(&mut message, max_size); + } + return Ok(message); + } + Err(e) => { + println!("[SESSION] Truncation fix didn't work: {}", e); + } + } + } + } + } + + Err(anyhow::anyhow!("Truncation fix failed")) +} + /// Attempt to truncate a JSON string by finding and truncating large text values fn truncate_json_string(json_str: &str, max_content_size: usize) -> String { // This is a heuristic approach - look for large text values in the JSON @@ -405,7 +791,10 @@ fn truncate_json_string(json_str: &str, max_content_size: usize) -> String { /// /// Returns default empty metadata if the file doesn't exist or has no metadata. pub fn read_metadata(session_file: &Path) -> Result { + println!("[SESSION] Reading metadata from: {:?}", session_file); + if !session_file.exists() { + println!("[SESSION] Session file doesn't exist, returning default metadata"); return Ok(SessionMetadata::default()); } @@ -415,16 +804,28 @@ pub fn read_metadata(session_file: &Path) -> Result { // Read just the first line if reader.read_line(&mut first_line)? > 0 { + println!("[SESSION] Read first line, attempting to parse as metadata..."); // Try to parse as metadata match serde_json::from_str::(&first_line) { - Ok(metadata) => Ok(metadata), - Err(_) => { + Ok(metadata) => { + println!( + "[SESSION] Successfully parsed metadata: description='{}'", + metadata.description + ); + Ok(metadata) + } + Err(e) => { // If the first line isn't metadata, return default + println!( + "[SESSION] First line is not valid metadata ({}), returning default", + e + ); Ok(SessionMetadata::default()) } } } else { // Empty file, return default + println!("[SESSION] File is empty, returning default metadata"); Ok(SessionMetadata::default()) } } @@ -438,7 +839,17 @@ pub async fn persist_messages( messages: &[Message], provider: Option>, ) -> Result<()> { - persist_messages_with_schedule_id(session_file, messages, provider, None).await + println!( + "[SESSION] persist_messages called with {} messages to: {:?}", + messages.len(), + session_file + ); + let result = persist_messages_with_schedule_id(session_file, messages, provider, None).await; + match &result { + Ok(_) => println!("[SESSION] persist_messages completed successfully"), + Err(e) => println!("[SESSION] persist_messages failed: {}", e), + } + result } /// Write messages to a session file with metadata, including an optional scheduled job ID @@ -477,28 +888,103 @@ pub async fn persist_messages_with_schedule_id( } } -/// Write messages to a session file with the provided metadata +/// Write messages to a session file with the provided metadata using atomic operations /// -/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format. +/// This function uses atomic file operations to prevent corruption: +/// 1. Writes to a temporary file first +/// 2. Uses fs2 file locking to prevent concurrent writes +/// 3. Atomically moves the temp file to the final location +/// 4. Includes comprehensive error handling and recovery pub fn save_messages_with_metadata( session_file: &Path, metadata: &SessionMetadata, messages: &[Message], ) -> Result<()> { - let file = File::create(session_file).expect("The path specified does not exist"); - let mut writer = io::BufWriter::new(file); + use fs2::FileExt; + + println!( + "[SESSION] Starting to save {} messages to: {:?}", + messages.len(), + session_file + ); + + // Create a temporary file in the same directory to ensure atomic move + let temp_file = session_file.with_extension("tmp"); + println!("[SESSION] Using temporary file: {:?}", temp_file); + + // Ensure the parent directory exists + if let Some(parent) = session_file.parent() { + println!("[SESSION] Ensuring parent directory exists: {:?}", parent); + fs::create_dir_all(parent)?; + } - // Write metadata as the first line - serde_json::to_writer(&mut writer, &metadata)?; - writeln!(writer)?; + // Create and lock the temporary file + println!("[SESSION] Creating and locking temporary file..."); + let file = fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&temp_file) + .map_err(|e| anyhow::anyhow!("Failed to create temporary file {:?}: {}", temp_file, e))?; + + // Get an exclusive lock on the file + println!("[SESSION] Acquiring exclusive lock..."); + file.try_lock_exclusive() + .map_err(|e| anyhow::anyhow!("Failed to lock file: {}", e))?; + + // Write to temporary file + { + println!( + "[SESSION] Writing metadata and {} messages to temporary file...", + messages.len() + ); + let mut writer = io::BufWriter::new(&file); - // Write all messages - for message in messages { - serde_json::to_writer(&mut writer, &message)?; + // Write metadata as the first line + println!("[SESSION] Writing metadata as first line..."); + serde_json::to_writer(&mut writer, &metadata) + .map_err(|e| anyhow::anyhow!("Failed to serialize metadata: {}", e))?; writeln!(writer)?; + + // Write all messages + println!("[SESSION] Writing {} messages...", messages.len()); + for (i, message) in messages.iter().enumerate() { + serde_json::to_writer(&mut writer, &message) + .map_err(|e| anyhow::anyhow!("Failed to serialize message {}: {}", i, e))?; + writeln!(writer)?; + + if (i + 1) % 50 == 0 { + println!("[SESSION] Written {} messages so far...", i + 1); + } + } + + // Ensure all data is written to disk + println!("[SESSION] Flushing writer buffer..."); + writer.flush()?; } - writer.flush()?; + // Sync to ensure data is persisted + println!("[SESSION] Syncing data to disk..."); + file.sync_all()?; + + // Release the lock + println!("[SESSION] Releasing file lock..."); + fs2::FileExt::unlock(&file).map_err(|e| anyhow::anyhow!("Failed to unlock file: {}", e))?; + + // Atomically move the temporary file to the final location + println!("[SESSION] Atomically moving temp file to final location..."); + fs::rename(&temp_file, session_file).map_err(|e| { + // Clean up temp file on failure + println!("[SESSION] Failed to move temp file, cleaning up..."); + let _ = fs::remove_file(&temp_file); + anyhow::anyhow!("Failed to move temporary file to final location: {}", e) + })?; + + println!( + "[SESSION] Successfully saved session file: {:?}", + session_file + ); + tracing::debug!("Successfully saved session file: {:?}", session_file); Ok(()) } @@ -583,6 +1069,79 @@ mod tests { use crate::message::MessageContent; use tempfile::tempdir; + #[test] + fn test_corruption_recovery() -> Result<()> { + let test_cases = vec![ + // Case 1: Unclosed quotes + ( + r#"{"role":"user","content":[{"type":"text","text":"Hello there}]"#, + "Unclosed JSON with truncated content", + ), + // Case 2: Trailing comma + ( + r#"{"role":"user","content":[{"type":"text","text":"Test"},]}"#, + "JSON with trailing comma", + ), + // Case 3: Missing closing brace + ( + r#"{"role":"user","content":[{"type":"text","text":"Test""#, + "Incomplete JSON structure", + ), + // Case 4: Control characters in text + ( + r#"{"role":"user","content":[{"type":"text","text":"Test\u{0000}with\u{0001}control\u{0002}chars"}]}"#, + "JSON with control characters", + ), + // Case 5: Partial message with role and text + ( + r#"broken{"role": "assistant", "text": "This is recoverable content"more broken"#, + "Partial message with recoverable content", + ), + ]; + + println!("[TEST] Starting corruption recovery tests..."); + for (i, (corrupt_json, desc)) in test_cases.iter().enumerate() { + println!("\n[TEST] Case {}: {}", i + 1, desc); + println!( + "[TEST] Input: {}", + if corrupt_json.len() > 100 { + &corrupt_json[..100] + } else { + corrupt_json + } + ); + + // Try to parse the corrupted JSON + match attempt_corruption_recovery(corrupt_json, Some(50000)) { + Ok(message) => { + println!("[TEST] Successfully recovered message"); + // Verify we got some content + if let Some(MessageContent::Text(text_content)) = message.content.first() { + assert!( + !text_content.text.is_empty(), + "Recovered message should have content" + ); + println!( + "[TEST] Recovered content: {}", + if text_content.text.len() > 50 { + format!("{}...", &text_content.text[..50]) + } else { + text_content.text.clone() + } + ); + } + } + Err(e) => { + println!("[TEST] Failed to recover: {}", e); + panic!("Failed to recover from case {}: {}", i + 1, desc); + } + } + } + + println!("\n[TEST] All corruption recovery tests passed!"); + Ok(()) + } + #[tokio::test] async fn test_read_write_messages() -> Result<()> { let dir = tempdir()?;