Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions crates/goose/src/config/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use serde_json::Value;
use std::collections::HashMap;
use std::env;
use std::fs::OpenOptions;
use std::io::Write;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use thiserror::Error;

Expand Down Expand Up @@ -342,8 +342,15 @@ impl Config {
/// - There is an error reading or writing the config file
/// - There is an error serializing the value
pub fn set_param(&self, key: &str, value: Value) -> Result<(), ConfigError> {
// Open the file with write permissions, create if it doesn't exist
// Ensure the directory exists
if let Some(parent) = self.config_path.parent() {
std::fs::create_dir_all(parent)
.map_err(|e| ConfigError::DirectoryError(e.to_string()))?;
}

// Open the file with read+write permissions, create if it doesn't exist
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
Expand All @@ -353,17 +360,22 @@ impl Config {
file.lock_exclusive()
.map_err(|e| ConfigError::LockError(e.to_string()))?;

// Load current values while holding the lock
let mut values = if self.config_path.exists() {
let file_content = std::fs::read_to_string(&self.config_path)?;
let yaml_value: serde_yaml::Value = serde_yaml::from_str(&file_content)?;
let json_value: Value = serde_json::to_value(yaml_value)?;
match json_value {
Value::Object(map) => map.into_iter().collect(),
_ => HashMap::new(),
// Load current values while holding the lock - read through the file handle
let mut values = {
let mut file_content = String::new();
file.seek(SeekFrom::Start(0))?;
file.read_to_string(&mut file_content)?;

if file_content.trim().is_empty() {
HashMap::new()
} else {
let yaml_value: serde_yaml::Value = serde_yaml::from_str(&file_content)?;
let json_value: Value = serde_json::to_value(yaml_value)?;
match json_value {
Value::Object(map) => map.into_iter().collect(),
_ => HashMap::new(),
}
}
} else {
HashMap::new()
};

// Modify values
Expand All @@ -373,6 +385,7 @@ impl Config {
let yaml_value = serde_yaml::to_string(&values)?;

// Write the contents using the same file handle
file.seek(SeekFrom::Start(0))?; // Seek to beginning before writing
file.set_len(0)?; // Clear the file
file.write_all(yaml_value.as_bytes())?;
file.sync_all()?;
Expand Down
Loading