Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ do_not_version/
/.env

/working_dir

# Error log artifacts from mcp replay tests
crates/goose/tests/mcp_replays/*errors.txt
25 changes: 17 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,9 @@ win-total-rls *allparam:
just win-bld-rls{{allparam}}
just win-run-rls

build-test-tools:
cargo build -p goose-test

record-mcp-tests: build-test-tools
GOOSE_RECORD_MCP=1 cargo test --package goose --test mcp_integration_test
git add crates/goose/tests/mcp_replays/
19 changes: 19 additions & 0 deletions crates/goose-test/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "goose-test"
edition.workspace = true
version.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
description.workspace = true

[lints]
workspace = true

[[bin]]
name = "capture"
path = "src/bin/capture.rs"

[dependencies]
clap = { version = "4.5.44", features = ["derive"] }
serde_json = "1.0.142"
45 changes: 45 additions & 0 deletions crates/goose-test/src/bin/capture.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::io;

use clap::{Parser, Subcommand, ValueEnum};

use goose_test::mcp::stdio::playback::playback;
use goose_test::mcp::stdio::record::record;

#[derive(Parser)]
struct Cli {
#[arg(value_enum)]
transport: Transport,
#[command(subcommand)]
mode: Mode,
}

#[derive(ValueEnum, Clone, Debug)]
enum Transport {
Stdio,
}

#[derive(Subcommand, Clone, Debug)]
enum Mode {
Record {
file: String,
command: String,
#[arg(trailing_var_arg = true, allow_hyphen_values = true)]
args: Vec<String>,
},
Playback {
file: String,
},
}

fn main() -> io::Result<()> {
let cli = Cli::parse();

match cli.mode {
Mode::Record {
file,
command,
args,
} => record(&file, &command, &args),
Mode::Playback { file } => playback(&file),
}
}
1 change: 1 addition & 0 deletions crates/goose-test/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod mcp;
1 change: 1 addition & 0 deletions crates/goose-test/src/mcp/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod stdio;
2 changes: 2 additions & 0 deletions crates/goose-test/src/mcp/stdio/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod playback;
pub mod record;
94 changes: 94 additions & 0 deletions crates/goose-test/src/mcp/stdio/playback.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use std::fs::File;
use std::io::{self, BufRead, BufReader, Write};
use std::process;

use serde_json::Value;

#[derive(Debug, Clone)]
enum StreamType {
Stdin,
Stdout,
Stderr,
}

#[derive(Debug, Clone)]
struct LogEntry {
stream_type: StreamType,
content: String,
}

fn parse_log_line(line: &str) -> Option<LogEntry> {
line.find(": ").and_then(|pos| {
let (prefix, content) = line.split_at(pos);
let content = &content[2..]; // Skip ": "

let stream_type = match prefix {
"STDIN" => StreamType::Stdin,
"STDOUT" => StreamType::Stdout,
"STDERR" => StreamType::Stderr,
_ => return None,
};

Some(LogEntry {
stream_type,
content: content.to_string(),
})
})
}

fn load_log_file(file_path: &str) -> io::Result<Vec<LogEntry>> {
let file = File::open(file_path)?;
let reader = BufReader::new(file);
let mut entries = Vec::new();

for line in reader.lines() {
let line = line?;
if let Some(entry) = parse_log_line(&line) {
entries.push(entry);
}
}

Ok(entries)
}

pub fn playback(log_file_path: &String) -> io::Result<()> {
let entries = load_log_file(log_file_path)?;
let errors_file = File::create(format!("{}.errors.txt", log_file_path))?;

let stdin = io::stdin();
let mut stdout = io::stdout();
let mut stderr = io::stderr();

for entry in entries {
match entry.stream_type {
StreamType::Stdout => {
writeln!(stdout, "{}", entry.content)?;
stdout.flush()?;
}
StreamType::Stderr => {
writeln!(stderr, "{}", entry.content)?;
stderr.flush()?;
}
StreamType::Stdin => {
// Wait for matching input
let mut input = String::new();
stdin.read_line(&mut input)?;
input = input.trim_end_matches('\n').to_string();

let input_value: Value = serde_json::from_str::<Value>(&input)?;
let entry_value: Value = serde_json::from_str::<Value>(&entry.content)?;
if input_value != entry_value {
writeln!(
&errors_file,
"expected:\n{}\ngot:\n{}",
serde_json::to_string(&input_value)?,
serde_json::to_string(&entry_value)?
)?;
process::exit(1);
}
}
}
}

Ok(())
}
115 changes: 115 additions & 0 deletions crates/goose-test/src/mcp/stdio/record.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::fs::OpenOptions;
use std::io::{self, BufRead, BufReader, Write};
use std::process::{ChildStdin, Command, Stdio};
use std::sync::mpsc;
use std::thread::{self, JoinHandle};

#[derive(Debug, Clone)]
enum StreamType {
Stdin,
Stdout,
Stderr,
}

fn handle_output_stream<R: BufRead + Send + 'static>(
reader: R,
sender: mpsc::Sender<(StreamType, String)>,
stream_type: StreamType,
mut output_writer: Box<dyn Write + Send>,
) -> JoinHandle<()> {
thread::spawn(move || {
for line in reader.lines() {
match line {
Ok(line) => {
let _ = sender.send((stream_type.clone(), line.clone()));

if writeln!(output_writer, "{}", line).is_err() {
break;
}
}
Err(_) => break,
}
}
})
}

fn handle_stdin_stream(
mut child_stdin: ChildStdin,
sender: mpsc::Sender<(StreamType, String)>,
) -> JoinHandle<()> {
thread::spawn(move || {
let stdin = io::stdin();

for line in stdin.lock().lines() {
match line {
Ok(line) => {
let _ = sender.send((StreamType::Stdin, line.clone()));

if writeln!(child_stdin, "{}", line).is_err() {
break;
}
}
Err(_) => break,
}
}
})
}

pub fn record(log_file_path: &String, cmd: &String, cmd_args: &[String]) -> io::Result<()> {
let (tx, rx) = mpsc::channel();

let log_file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(log_file_path)?;

let mut child = Command::new(cmd)
.args(cmd_args.iter())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.inspect_err(|e| eprintln!("Failed to execute command '{}': {}", &cmd, e))?;

let child_stdin = child.stdin.take().unwrap();
let child_stdout = child.stdout.take().unwrap();
let child_stderr = child.stderr.take().unwrap();

let stdin_handle = handle_stdin_stream(child_stdin, tx.clone());
let stdout_handle = handle_output_stream(
BufReader::new(child_stdout),
tx.clone(),
StreamType::Stdout,
Box::new(io::stdout()),
);
let stderr_handle = handle_output_stream(
BufReader::new(child_stderr),
tx.clone(),
StreamType::Stderr,
Box::new(io::stderr()),
);

thread::spawn(move || {
let mut log_file = log_file;
for (stream_type, line) in rx {
let prefix = match stream_type {
StreamType::Stdin => "STDIN",
StreamType::Stdout => "STDOUT",
StreamType::Stderr => "STDERR",
};
if let Err(e) = writeln!(log_file, "{}: {}", prefix, line) {
eprintln!("Error writing to log file: {}", e);
}
log_file.flush().ok();
}
});

child.wait()?;

stdin_handle.join().ok();
stdout_handle.join().ok();
stderr_handle.join().ok();

Ok(())
}
1 change: 1 addition & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ tokio = { version = "1.43", features = ["full"] }
temp-env = "0.3.6"
dotenvy = "0.15.7"
ctor = "0.2.9"
test-case = "3.3"

[[example]]
name = "agent"
Expand Down
Loading
Loading