Skip to content

Commit

Permalink
feat: maintain backwards compatibility with conversastion.toml
Browse files Browse the repository at this point in the history
- Use conversation.toml for storing latest conversation state
- Remove unnecessary directory creation in config initialization
- Update test coverage
  • Loading branch information
bytesoverflow committed Nov 6, 2024
1 parent 6aad6f7 commit 9f91859
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 76 deletions.
12 changes: 2 additions & 10 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{path::PathBuf, process::Command};

use self::{
api::{api_keys_path, generate_api_keys_file, get_api_config},
prompt::{generate_prompts_file, get_prompts, prompts_path, conversations_path},
prompt::{generate_prompts_file, get_prompts, prompts_path},
};
use crate::utils::is_interactive;

Expand Down Expand Up @@ -58,12 +58,6 @@ pub fn ensure_config_files() -> std::io::Result<()> {
}
};

// Create the conversations directory if it doesn't exist
if !conversations_path().exists() {
std::fs::create_dir_all(conversations_path())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to create conversations directory: {}", e)))?;
}

Ok(())
}

Expand Down Expand Up @@ -113,7 +107,7 @@ mod tests {
config::{
api::{api_keys_path, default_timeout_seconds, Api, ApiConfig},
ensure_config_files,
prompt::{prompts_path, conversations_path, Prompt},
prompt::{prompts_path, Prompt},
resolve_config_path, CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH,
},
utils::IS_NONINTERACTIVE_ENV_VAR,
Expand Down Expand Up @@ -181,7 +175,6 @@ mod tests {

assert!(!api_keys_path.exists());
assert!(!prompts_path.exists());
assert!(!conversations_path().exists());

let result = ensure_config_files();

Expand All @@ -194,7 +187,6 @@ mod tests {

assert!(api_keys_path.exists());
assert!(prompts_path.exists());
assert!(conversations_path().exists());

Ok(())
}
Expand Down
184 changes: 147 additions & 37 deletions src/config/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::path::PathBuf;
use crate::config::{api::Api, resolve_config_path};

const PROMPT_FILE: &str = "prompts.toml";
const CONVERSATIONS_PATH: &str = "conversations/";
const CONVERSATION_FILE: &str = "conversation.toml";
const CONVERSATIONS_PATH: &str = "saved_conversations";

#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Prompt {
Expand Down Expand Up @@ -96,31 +97,66 @@ pub(super) fn prompts_path() -> PathBuf {
resolve_config_path().join(PROMPT_FILE)
}

pub fn conversation_file_path() -> PathBuf {
resolve_config_path().join(CONVERSATION_FILE)
}

// Get the path to the conversations directory
pub fn conversations_path() -> PathBuf {
resolve_config_path().join(CONVERSATIONS_PATH)
}

// Get the path to a specific conversation file
pub fn conversation_file_path(name: &str) -> PathBuf {
pub fn named_conversation_path(name: &str) -> PathBuf {
conversations_path().join(format!("{}.toml", name))
}

// Get the last conversation as a prompt, if it exists
pub fn get_last_conversation_as_prompt(name: &str) -> Option<Prompt> {
let file_path = conversation_file_path(name);
if !file_path.exists() {
return None;
pub fn get_last_conversation_as_prompt(name: Option<&str>) -> Option<Prompt> {
if let Some(name) = name {
let named_path = named_conversation_path(name);
if !named_path.exists() {
return None;
}
let content = fs::read_to_string(named_path)
.unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
named_conversation_path(name),
error
)
});
Some(toml::from_str(&content).expect("failed to load the conversation file"))
} else {
let path = conversation_file_path();
if !path.exists() {
return None;
}
let content = fs::read_to_string(path)
.unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
conversation_file_path(),
error
)
});
Some(toml::from_str(&content).expect("failed to load the conversation file"))
}
}

pub fn save_conversation(prompt: &Prompt, name: Option<&str>) -> std::io::Result<()> {
let toml_string = toml::to_string(prompt).expect("Failed to serialize prompt");

// Always save to conversation.toml
fs::write(conversation_file_path(), &toml_string)?;

// If name is provided, also save to named conversation file
if let Some(name) = name {
fs::create_dir_all(conversations_path())?;
fs::write(named_conversation_path(name), &toml_string)?;
}

let content = fs::read_to_string(file_path).unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
conversation_file_path(name),
error
)
});
Some(toml::from_str(&content).expect("failed to load the conversation file"))
Ok(())
}

pub(super) fn generate_prompts_file() -> std::io::Result<()> {
Expand Down Expand Up @@ -153,40 +189,114 @@ pub fn get_prompts() -> HashMap<String, Prompt> {
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
use crate::config::prompt::Prompt;
use serial_test::serial;

fn setup() -> tempfile::TempDir {
let temp_dir = tempdir().unwrap();
std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path());
temp_dir
}

fn create_test_prompt() -> Prompt {
let mut prompt = Prompt::default();
prompt.messages = vec![(Message::user("test"))];
prompt
}

#[test]
fn test_conversation_file_path() {
let name = "test_conversation";
let file_path = conversation_file_path(name);
assert_eq!(
file_path.file_name().unwrap().to_str().unwrap(),
format!("{}.toml", name)
);
assert_eq!(file_path.parent().unwrap(), conversations_path());
#[serial]
fn test_get_and_save_default_conversation() {
let _temp_dir = setup();
let test_prompt = create_test_prompt();

// Test saving conversation
save_conversation(&test_prompt, None).unwrap();
assert!(conversation_file_path().exists());

// Test retrieving conversation
let loaded_prompt = get_last_conversation_as_prompt(None).unwrap();
assert_eq!(loaded_prompt, test_prompt);
}

#[test]
fn test_get_last_conversation_as_prompt() {
let name = "test_conversation";
let file_path = conversation_file_path(name);
let prompt = Prompt::default();
#[serial]
fn test_get_and_save_named_conversation() {
let _temp_dir = setup();
let test_prompt = create_test_prompt();
let conv_name = "test_conversation";

// Create a test conversation file
let toml_string = toml::to_string(&prompt).expect("Failed to serialize prompt");
fs::write(&file_path, toml_string).expect("Failed to write test conversation file");
// Test saving named conversation
save_conversation(&test_prompt, Some(conv_name)).unwrap();
assert!(named_conversation_path(conv_name).exists());
assert!(conversation_file_path().exists()); // Should also save to default location

let loaded_prompt = get_last_conversation_as_prompt(name);
assert_eq!(loaded_prompt, Some(prompt));
// Test retrieving named conversation
let loaded_prompt = get_last_conversation_as_prompt(Some(conv_name)).unwrap();
assert_eq!(loaded_prompt, test_prompt);
}

// Clean up the test conversation file
fs::remove_file(&file_path).expect("Failed to remove test conversation file");
#[test]
#[serial]
fn test_nonexistent_conversation() {
let _temp_dir = setup();

// Test getting nonexistent default conversation
assert!(get_last_conversation_as_prompt(None).is_none());

// Test getting nonexistent named conversation
assert!(get_last_conversation_as_prompt(Some("nonexistent")).is_none());
}

#[test]
#[serial]
fn test_conversation_file_contents() {
let _temp_dir = setup();
let test_prompt = create_test_prompt();
let conv_name = "test_conversation";

// Save conversation
save_conversation(&test_prompt, Some(conv_name)).unwrap();

// Verify default and named files have identical content
let default_content = fs::read_to_string(conversation_file_path()).unwrap();
let named_content = fs::read_to_string(named_conversation_path(conv_name)).unwrap();
assert_eq!(default_content, named_content);

// Verify content can be parsed back to original prompt
let parsed_prompt: Prompt = toml::from_str(&default_content).unwrap();
assert_eq!(parsed_prompt, test_prompt);
}

#[test]
fn test_get_last_conversation_as_prompt_missing_file() {
let name = "nonexistent_conversation";
let loaded_prompt = get_last_conversation_as_prompt(name);
assert_eq!(loaded_prompt, None);
#[serial]
fn test_generate_prompts_file() {
let _temp_dir = setup();

// Test file generation
generate_prompts_file().unwrap();
assert!(prompts_path().exists());

// Verify file is valid TOML and contains expected content
let content = fs::read_to_string(prompts_path()).unwrap();
let prompts: HashMap<String, Prompt> = toml::from_str(&content).unwrap();
assert!(!prompts.is_empty());
}

#[test]
#[serial]
fn test_get_prompts() {
let _temp_dir = setup();

// Generate prompts file
generate_prompts_file().unwrap();

// Test loading prompts
let prompts = get_prompts();
assert!(!prompts.is_empty());

// Verify at least one default prompt exists
assert!(prompts.contains_key("default"));
}
}
Loading

0 comments on commit 9f91859

Please sign in to comment.