Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 76 additions & 97 deletions crates/goose/src/providers/claude_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ struct CliProcess {
reader: BufReader<tokio::process::ChildStdout>,
#[allow(dead_code)]
stderr_handle: tokio::task::JoinHandle<String>,
messages_sent: usize,
}

impl Drop for CliProcess {
Expand All @@ -56,23 +55,15 @@ pub struct ClaudeCodeProvider {
}

impl ClaudeCodeProvider {
pub async fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let command: String = config.get_claude_code_command().unwrap_or_default().into();
let resolved_command = SearchPaths::builder().with_npm().resolve(&command)?;

Ok(Self {
command: resolved_command,
model,
name: CLAUDE_CODE_PROVIDER_NAME.to_string(),
cli_process: tokio::sync::OnceCell::new(),
})
}

/// Build Anthropic content blocks from goose messages, supporting text and images.
fn messages_to_content_blocks(&self, messages: &[Message]) -> Vec<Value> {
/// Build content blocks from the last user message only — the CLI maintains
/// conversation context internally per session_id.
fn last_user_content_blocks(&self, messages: &[Message]) -> Vec<Value> {
let msgs = match messages.iter().rev().find(|m| m.role == Role::User) {
Some(msg) => std::slice::from_ref(msg),
None => messages,
};
let mut blocks: Vec<Value> = Vec::new();
for message in messages.iter().filter(|m| m.is_agent_visible()) {
for message in msgs.iter().filter(|m| m.is_agent_visible()) {
let prefix = match message.role {
Role::User => "Human: ",
Role::Assistant => "Assistant: ",
Expand Down Expand Up @@ -235,6 +226,9 @@ impl ClaudeCodeProvider {

// Combine all text content into a single message
let combined_text = all_text_content.join("\n\n");
if combined_text.contains("Prompt is too long") {
Comment thread
codefromthecrypt marked this conversation as resolved.
return Err(ProviderError::ContextLengthExceeded(combined_text));
}
if combined_text.is_empty() {
return Err(ProviderError::RequestFailed(
"No text content found in response".to_string(),
Expand All @@ -257,6 +251,7 @@ impl ClaudeCodeProvider {
system: &str,
messages: &[Message],
_tools: &[Tool],
session_id: &str,
) -> Result<Vec<String>, ProviderError> {
let filtered_system = filter_extensions_from_system_prompt(system);

Expand Down Expand Up @@ -331,25 +326,16 @@ impl ClaudeCodeProvider {
stdin,
reader: BufReader::new(stdout),
stderr_handle,
messages_sent: 0,
}))
})
.await?;

let mut process = process_mutex.lock().await;

// Build content from new messages only (skip already-sent ones).
// If messages is shorter than messages_sent, the caller started a fresh
// conversation on the same provider instance — send everything.
let new_messages = if process.messages_sent > 0 && process.messages_sent < messages.len() {
&messages[process.messages_sent..]
} else {
messages
};
let new_blocks = self.messages_to_content_blocks(new_messages);
let blocks = self.last_user_content_blocks(messages);

// Write NDJSON line to stdin
let ndjson_line = build_stream_json_input(&new_blocks);
let ndjson_line = build_stream_json_input(&blocks, session_id);
process
Comment thread
codefromthecrypt marked this conversation as resolved.
.stdin
.write_all(ndjson_line.as_bytes())
Expand Down Expand Up @@ -399,9 +385,6 @@ impl ClaudeCodeProvider {
}
}

// Update messages_sent for next turn
process.messages_sent = messages.len();

tracing::debug!("Command executed successfully, got {} lines", lines.len());
for (i, line) in lines.iter().enumerate() {
tracing::debug!("Line {}: {}", i, line);
Expand Down Expand Up @@ -456,8 +439,8 @@ impl ClaudeCodeProvider {
}
}

fn build_stream_json_input(content_blocks: &[Value]) -> String {
let msg = json!({"type":"user","message":{"role":"user","content":content_blocks}});
fn build_stream_json_input(content_blocks: &[Value], session_id: &str) -> String {
let msg = json!({"type":"user","session_id":session_id,"message":{"role":"user","content":content_blocks}});
serde_json::to_string(&msg).expect("serializing JSON content blocks cannot fail")
}

Expand All @@ -479,7 +462,18 @@ impl ProviderDef for ClaudeCodeProvider {
}

fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
Box::pin(async move {
let config = crate::config::Config::global();
let command: String = config.get_claude_code_command().unwrap_or_default().into();
let resolved_command = SearchPaths::builder().with_npm().resolve(command)?;

Ok(Self {
command: resolved_command,
model,
name: CLAUDE_CODE_PROVIDER_NAME.to_string(),
cli_process: tokio::sync::OnceCell::new(),
})
})
}
}

Expand Down Expand Up @@ -507,7 +501,7 @@ impl Provider for ClaudeCodeProvider {
)]
async fn complete_with_model(
&self,
_session_id: Option<&str>, // create_session == YYYYMMDD_N, but --session-id requires a UUID
Copy link
Copy Markdown
Collaborator Author

@codefromthecrypt codefromthecrypt Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this only applies to CLI validation, not the ndjson itself.

session_id: Option<&str>,
model_config: &ModelConfig,
system: &str,
messages: &[Message],
Expand All @@ -518,7 +512,9 @@ impl Provider for ClaudeCodeProvider {
return self.generate_simple_session_description(messages);
}

let json_lines = self.execute_command(system, messages, tools).await?;
// session_id is None before a session is created (e.g. model listing).
let sid = session_id.unwrap_or("default");
let json_lines = self.execute_command(system, messages, tools, sid).await?;

let (message, usage) = self.parse_claude_response(&json_lines)?;

Expand Down Expand Up @@ -548,6 +544,7 @@ impl Provider for ClaudeCodeProvider {
#[cfg(test)]
mod tests {
use super::*;
use goose_test_support::session::TEST_SESSION_ID;
use serde_json::json;
use test_case::test_case;

Expand Down Expand Up @@ -576,98 +573,75 @@ mod tests {
}

#[test_case(
&[],
build_messages(&[]),
&[]
; "empty"
)]
#[test_case(
&[("user", "Hello", None)],
build_messages(&[("user", "Hello", None)]),
&[json!({"type":"text","text":"Human: Hello"})]
; "single_user"
)]
#[test_case(
&[("user", "Hello", None), ("assistant", "Hi there!", None)],
&[json!({"type":"text","text":"Human: Hello"}), json!({"type":"text","text":"Assistant: Hi there!"})]
; "user_and_assistant"
build_messages(&[("user", "Hello", None), ("assistant", "Hi there!", None)]),
&[json!({"type":"text","text":"Human: Hello"})]
; "picks_last_user_ignores_assistant"
)]
#[test_case(
&[("user", "Describe this", Some(("base64data", "image/png")))],
build_messages(&[("user", "First", None), ("assistant", "Reply", None), ("user", "Second", None)]),
&[json!({"type":"text","text":"Human: Second"})]
; "multi_turn_picks_last_user"
)]
#[test_case(
build_messages(&[("user", "Describe this", Some(("base64data", "image/png")))]),
&[json!({"type":"text","text":"Human: Describe this"}),
json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"base64data"}})]
; "user_with_image"
)]
#[test_case(
&[("user", "", Some(("iVBORw0KGgo", "image/png")))],
build_messages(&[("user", "", Some(("iVBORw0KGgo", "image/png")))]),
&[json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBORw0KGgo"}})]
; "image_only"
)]
fn test_messages_to_content_blocks(pairs: &[MsgSpec], expected: &[Value]) {
#[test_case(
vec![Message::new(Role::Assistant, 0, vec![
MessageContent::tool_request("call_123", Ok(rmcp::model::CallToolRequestParams {
name: "developer__shell".into(),
arguments: Some(serde_json::from_value(json!({"cmd": "ls"})).unwrap()),
meta: None, task: None,
}))
])],
&[json!({"type":"text","text":"Assistant: [tool_use: developer__shell id=call_123]"})]
; "tool_request_no_user_fallback"
)]
#[test_case(
vec![Message::new(Role::User, 0, vec![
MessageContent::tool_response("call_123", Ok(rmcp::model::CallToolResult {
content: vec![rmcp::model::Content::text("file1.txt\nfile2.txt")],
is_error: None, structured_content: None, meta: None,
}))
])],
&[json!({"type":"text","text":"Human: [tool_result id=call_123] file1.txt\nfile2.txt"})]
; "tool_response"
)]
fn test_last_user_content_blocks(messages: Vec<Message>, expected: &[Value]) {
let provider = make_provider();
let messages = build_messages(pairs);
let blocks = provider.messages_to_content_blocks(&messages);
let blocks = provider.last_user_content_blocks(&messages);
assert_eq!(blocks, expected);
}

#[test]
fn test_messages_to_content_blocks_tool_request() {
use rmcp::model::CallToolRequestParams;
let provider = make_provider();
let tool_call = Ok(CallToolRequestParams {
name: "developer__shell".into(),
arguments: Some(serde_json::from_value(json!({"cmd": "ls"})).unwrap()),
meta: None,
task: None,
});
let msg = Message::new(
Role::Assistant,
0,
vec![MessageContent::tool_request("call_123", tool_call)],
);
let blocks = provider.messages_to_content_blocks(&[msg]);
assert_eq!(
blocks,
vec![
json!({"type":"text","text":"Assistant: [tool_use: developer__shell id=call_123]"})
]
);
}

#[test]
fn test_messages_to_content_blocks_tool_response() {
use rmcp::model::{CallToolResult, Content};
let provider = make_provider();
let result = CallToolResult {
content: vec![Content::text("file1.txt\nfile2.txt")],
is_error: None,
structured_content: None,
meta: None,
};
let msg = Message::new(
Role::User,
0,
vec![MessageContent::tool_response("call_123", Ok(result))],
);
let blocks = provider.messages_to_content_blocks(&[msg]);
assert_eq!(
blocks,
vec![
json!({"type":"text","text":"Human: [tool_result id=call_123] file1.txt\nfile2.txt"})
]
);
}

#[test_case(
&[json!({"type":"text","text":"Hello"})],
json!({"type":"user","message":{"role":"user","content":[{"type":"text","text":"Hello"}]}})
json!({"type":"user","session_id":TEST_SESSION_ID,"message":{"role":"user","content":[{"type":"text","text":"Hello"}]}})
; "text_block"
)]
#[test_case(
&[json!({"type":"text","text":"Look"}), json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}})],
json!({"type":"user","message":{"role":"user","content":[{"type":"text","text":"Look"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}}]}})
json!({"type":"user","session_id":TEST_SESSION_ID,"message":{"role":"user","content":[{"type":"text","text":"Look"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}}]}})
; "text_and_image_blocks"
)]
fn test_build_stream_json_input(blocks: &[Value], expected: Value) {
let line = build_stream_json_input(blocks);
let line = build_stream_json_input(blocks, TEST_SESSION_ID);
let parsed: Value = serde_json::from_str(&line).unwrap();
assert_eq!(parsed, expected);
}
Expand Down Expand Up @@ -733,6 +707,11 @@ mod tests {
ProviderError::RequestFailed("Claude CLI error: Model not supported".into())
; "generic_error"
)]
#[test_case(
&[r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Prompt is too long"}]}}"#],
ProviderError::ContextLengthExceeded("Prompt is too long".into())
; "prompt_too_long_exact"
)]
fn test_parse_claude_response_err(lines: &[&str], expected: ProviderError) {
let provider = make_provider();
let lines: Vec<String> = lines.iter().map(|s| s.to_string()).collect();
Expand Down
78 changes: 38 additions & 40 deletions crates/goose/src/providers/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,45 +53,6 @@ pub struct CodexProvider {
}

impl CodexProvider {
pub async fn from_env(model: ModelConfig) -> Result<Self> {
let config = Config::global();
let command: String = config.get_codex_command().unwrap_or_default().into();
let resolved_command = SearchPaths::builder().with_npm().resolve(&command)?;

// Get reasoning effort from config, default to "high"
let reasoning_effort = config
.get_codex_reasoning_effort()
.map(String::from)
.unwrap_or_else(|_| "high".to_string());

// Validate reasoning effort
let reasoning_effort =
if Self::supports_reasoning_effort(&model.model_name, &reasoning_effort) {
reasoning_effort
} else {
tracing::warn!(
"Invalid CODEX_REASONING_EFFORT '{}' for model '{}', using 'high'",
reasoning_effort,
model.model_name
);
"high".to_string()
};

// Get skip_git_check from config, default to false
let skip_git_check = config
.get_codex_skip_git_check()
.map(|s| s.to_lowercase() == "true")
.unwrap_or(false);

Ok(Self {
command: resolved_command,
model,
name: CODEX_PROVIDER_NAME.to_string(),
reasoning_effort,
skip_git_check,
})
}

fn supports_reasoning_effort(model_name: &str, reasoning_effort: &str) -> bool {
if !CODEX_REASONING_LEVELS.contains(&reasoning_effort) {
return false;
Expand Down Expand Up @@ -600,7 +561,44 @@ impl ProviderDef for CodexProvider {
}

fn from_env(model: ModelConfig) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
Box::pin(async move {
let config = Config::global();
let command: String = config.get_codex_command().unwrap_or_default().into();
let resolved_command = SearchPaths::builder().with_npm().resolve(command)?;

// Get reasoning effort from config, default to "high"
let reasoning_effort = config
.get_codex_reasoning_effort()
.map(String::from)
.unwrap_or_else(|_| "high".to_string());

// Validate reasoning effort
let reasoning_effort =
if Self::supports_reasoning_effort(&model.model_name, &reasoning_effort) {
reasoning_effort
} else {
tracing::warn!(
"Invalid CODEX_REASONING_EFFORT '{}' for model '{}', using 'high'",
reasoning_effort,
model.model_name
);
"high".to_string()
};

// Get skip_git_check from config, default to false
let skip_git_check = config
.get_codex_skip_git_check()
.map(|s| s.to_lowercase() == "true")
.unwrap_or(false);

Ok(Self {
command: resolved_command,
model,
name: CODEX_PROVIDER_NAME.to_string(),
reasoning_effort,
skip_git_check,
})
})
}
}

Expand Down
Loading
Loading