diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 94f41aea2d27..ebc89ae0309f 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -31,7 +31,6 @@ struct CliProcess { reader: BufReader, #[allow(dead_code)] stderr_handle: tokio::task::JoinHandle, - messages_sent: usize, } impl Drop for CliProcess { @@ -56,23 +55,15 @@ pub struct ClaudeCodeProvider { } impl ClaudeCodeProvider { - pub async fn from_env(model: ModelConfig) -> Result { - let config = crate::config::Config::global(); - let command: String = config.get_claude_code_command().unwrap_or_default().into(); - let resolved_command = SearchPaths::builder().with_npm().resolve(&command)?; - - Ok(Self { - command: resolved_command, - model, - name: CLAUDE_CODE_PROVIDER_NAME.to_string(), - cli_process: tokio::sync::OnceCell::new(), - }) - } - - /// Build Anthropic content blocks from goose messages, supporting text and images. - fn messages_to_content_blocks(&self, messages: &[Message]) -> Vec { + /// Build content blocks from the last user message only — the CLI maintains + /// conversation context internally per session_id. + fn last_user_content_blocks(&self, messages: &[Message]) -> Vec { + let msgs = match messages.iter().rev().find(|m| m.role == Role::User) { + Some(msg) => std::slice::from_ref(msg), + None => messages, + }; let mut blocks: Vec = Vec::new(); - for message in messages.iter().filter(|m| m.is_agent_visible()) { + for message in msgs.iter().filter(|m| m.is_agent_visible()) { let prefix = match message.role { Role::User => "Human: ", Role::Assistant => "Assistant: ", @@ -235,6 +226,9 @@ impl ClaudeCodeProvider { // Combine all text content into a single message let combined_text = all_text_content.join("\n\n"); + if combined_text.contains("Prompt is too long") { + return Err(ProviderError::ContextLengthExceeded(combined_text)); + } if combined_text.is_empty() { return Err(ProviderError::RequestFailed( "No text content found in response".to_string(), @@ -257,6 +251,7 @@ impl ClaudeCodeProvider { system: &str, messages: &[Message], _tools: &[Tool], + session_id: &str, ) -> Result, ProviderError> { let filtered_system = filter_extensions_from_system_prompt(system); @@ -331,25 +326,16 @@ impl ClaudeCodeProvider { stdin, reader: BufReader::new(stdout), stderr_handle, - messages_sent: 0, })) }) .await?; let mut process = process_mutex.lock().await; - // Build content from new messages only (skip already-sent ones). - // If messages is shorter than messages_sent, the caller started a fresh - // conversation on the same provider instance — send everything. - let new_messages = if process.messages_sent > 0 && process.messages_sent < messages.len() { - &messages[process.messages_sent..] - } else { - messages - }; - let new_blocks = self.messages_to_content_blocks(new_messages); + let blocks = self.last_user_content_blocks(messages); // Write NDJSON line to stdin - let ndjson_line = build_stream_json_input(&new_blocks); + let ndjson_line = build_stream_json_input(&blocks, session_id); process .stdin .write_all(ndjson_line.as_bytes()) @@ -399,9 +385,6 @@ impl ClaudeCodeProvider { } } - // Update messages_sent for next turn - process.messages_sent = messages.len(); - tracing::debug!("Command executed successfully, got {} lines", lines.len()); for (i, line) in lines.iter().enumerate() { tracing::debug!("Line {}: {}", i, line); @@ -456,8 +439,8 @@ impl ClaudeCodeProvider { } } -fn build_stream_json_input(content_blocks: &[Value]) -> String { - let msg = json!({"type":"user","message":{"role":"user","content":content_blocks}}); +fn build_stream_json_input(content_blocks: &[Value], session_id: &str) -> String { + let msg = json!({"type":"user","session_id":session_id,"message":{"role":"user","content":content_blocks}}); serde_json::to_string(&msg).expect("serializing JSON content blocks cannot fail") } @@ -479,7 +462,18 @@ impl ProviderDef for ClaudeCodeProvider { } fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { - Box::pin(Self::from_env(model)) + Box::pin(async move { + let config = crate::config::Config::global(); + let command: String = config.get_claude_code_command().unwrap_or_default().into(); + let resolved_command = SearchPaths::builder().with_npm().resolve(command)?; + + Ok(Self { + command: resolved_command, + model, + name: CLAUDE_CODE_PROVIDER_NAME.to_string(), + cli_process: tokio::sync::OnceCell::new(), + }) + }) } } @@ -507,7 +501,7 @@ impl Provider for ClaudeCodeProvider { )] async fn complete_with_model( &self, - _session_id: Option<&str>, // create_session == YYYYMMDD_N, but --session-id requires a UUID + session_id: Option<&str>, model_config: &ModelConfig, system: &str, messages: &[Message], @@ -518,7 +512,9 @@ impl Provider for ClaudeCodeProvider { return self.generate_simple_session_description(messages); } - let json_lines = self.execute_command(system, messages, tools).await?; + // session_id is None before a session is created (e.g. model listing). + let sid = session_id.unwrap_or("default"); + let json_lines = self.execute_command(system, messages, tools, sid).await?; let (message, usage) = self.parse_claude_response(&json_lines)?; @@ -548,6 +544,7 @@ impl Provider for ClaudeCodeProvider { #[cfg(test)] mod tests { use super::*; + use goose_test_support::session::TEST_SESSION_ID; use serde_json::json; use test_case::test_case; @@ -576,98 +573,75 @@ mod tests { } #[test_case( - &[], + build_messages(&[]), &[] ; "empty" )] #[test_case( - &[("user", "Hello", None)], + build_messages(&[("user", "Hello", None)]), &[json!({"type":"text","text":"Human: Hello"})] ; "single_user" )] #[test_case( - &[("user", "Hello", None), ("assistant", "Hi there!", None)], - &[json!({"type":"text","text":"Human: Hello"}), json!({"type":"text","text":"Assistant: Hi there!"})] - ; "user_and_assistant" + build_messages(&[("user", "Hello", None), ("assistant", "Hi there!", None)]), + &[json!({"type":"text","text":"Human: Hello"})] + ; "picks_last_user_ignores_assistant" )] #[test_case( - &[("user", "Describe this", Some(("base64data", "image/png")))], + build_messages(&[("user", "First", None), ("assistant", "Reply", None), ("user", "Second", None)]), + &[json!({"type":"text","text":"Human: Second"})] + ; "multi_turn_picks_last_user" + )] + #[test_case( + build_messages(&[("user", "Describe this", Some(("base64data", "image/png")))]), &[json!({"type":"text","text":"Human: Describe this"}), json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"base64data"}})] ; "user_with_image" )] #[test_case( - &[("user", "", Some(("iVBORw0KGgo", "image/png")))], + build_messages(&[("user", "", Some(("iVBORw0KGgo", "image/png")))]), &[json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBORw0KGgo"}})] ; "image_only" )] - fn test_messages_to_content_blocks(pairs: &[MsgSpec], expected: &[Value]) { + #[test_case( + vec![Message::new(Role::Assistant, 0, vec![ + MessageContent::tool_request("call_123", Ok(rmcp::model::CallToolRequestParams { + name: "developer__shell".into(), + arguments: Some(serde_json::from_value(json!({"cmd": "ls"})).unwrap()), + meta: None, task: None, + })) + ])], + &[json!({"type":"text","text":"Assistant: [tool_use: developer__shell id=call_123]"})] + ; "tool_request_no_user_fallback" + )] + #[test_case( + vec![Message::new(Role::User, 0, vec![ + MessageContent::tool_response("call_123", Ok(rmcp::model::CallToolResult { + content: vec![rmcp::model::Content::text("file1.txt\nfile2.txt")], + is_error: None, structured_content: None, meta: None, + })) + ])], + &[json!({"type":"text","text":"Human: [tool_result id=call_123] file1.txt\nfile2.txt"})] + ; "tool_response" + )] + fn test_last_user_content_blocks(messages: Vec, expected: &[Value]) { let provider = make_provider(); - let messages = build_messages(pairs); - let blocks = provider.messages_to_content_blocks(&messages); + let blocks = provider.last_user_content_blocks(&messages); assert_eq!(blocks, expected); } - #[test] - fn test_messages_to_content_blocks_tool_request() { - use rmcp::model::CallToolRequestParams; - let provider = make_provider(); - let tool_call = Ok(CallToolRequestParams { - name: "developer__shell".into(), - arguments: Some(serde_json::from_value(json!({"cmd": "ls"})).unwrap()), - meta: None, - task: None, - }); - let msg = Message::new( - Role::Assistant, - 0, - vec![MessageContent::tool_request("call_123", tool_call)], - ); - let blocks = provider.messages_to_content_blocks(&[msg]); - assert_eq!( - blocks, - vec![ - json!({"type":"text","text":"Assistant: [tool_use: developer__shell id=call_123]"}) - ] - ); - } - - #[test] - fn test_messages_to_content_blocks_tool_response() { - use rmcp::model::{CallToolResult, Content}; - let provider = make_provider(); - let result = CallToolResult { - content: vec![Content::text("file1.txt\nfile2.txt")], - is_error: None, - structured_content: None, - meta: None, - }; - let msg = Message::new( - Role::User, - 0, - vec![MessageContent::tool_response("call_123", Ok(result))], - ); - let blocks = provider.messages_to_content_blocks(&[msg]); - assert_eq!( - blocks, - vec![ - json!({"type":"text","text":"Human: [tool_result id=call_123] file1.txt\nfile2.txt"}) - ] - ); - } - #[test_case( &[json!({"type":"text","text":"Hello"})], - json!({"type":"user","message":{"role":"user","content":[{"type":"text","text":"Hello"}]}}) + json!({"type":"user","session_id":TEST_SESSION_ID,"message":{"role":"user","content":[{"type":"text","text":"Hello"}]}}) ; "text_block" )] #[test_case( &[json!({"type":"text","text":"Look"}), json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}})], - json!({"type":"user","message":{"role":"user","content":[{"type":"text","text":"Look"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}}]}}) + json!({"type":"user","session_id":TEST_SESSION_ID,"message":{"role":"user","content":[{"type":"text","text":"Look"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}}]}}) ; "text_and_image_blocks" )] fn test_build_stream_json_input(blocks: &[Value], expected: Value) { - let line = build_stream_json_input(blocks); + let line = build_stream_json_input(blocks, TEST_SESSION_ID); let parsed: Value = serde_json::from_str(&line).unwrap(); assert_eq!(parsed, expected); } @@ -733,6 +707,11 @@ mod tests { ProviderError::RequestFailed("Claude CLI error: Model not supported".into()) ; "generic_error" )] + #[test_case( + &[r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Prompt is too long"}]}}"#], + ProviderError::ContextLengthExceeded("Prompt is too long".into()) + ; "prompt_too_long_exact" + )] fn test_parse_claude_response_err(lines: &[&str], expected: ProviderError) { let provider = make_provider(); let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs index 81e47b2efe78..8f9a3e119477 100644 --- a/crates/goose/src/providers/codex.rs +++ b/crates/goose/src/providers/codex.rs @@ -53,45 +53,6 @@ pub struct CodexProvider { } impl CodexProvider { - pub async fn from_env(model: ModelConfig) -> Result { - let config = Config::global(); - let command: String = config.get_codex_command().unwrap_or_default().into(); - let resolved_command = SearchPaths::builder().with_npm().resolve(&command)?; - - // Get reasoning effort from config, default to "high" - let reasoning_effort = config - .get_codex_reasoning_effort() - .map(String::from) - .unwrap_or_else(|_| "high".to_string()); - - // Validate reasoning effort - let reasoning_effort = - if Self::supports_reasoning_effort(&model.model_name, &reasoning_effort) { - reasoning_effort - } else { - tracing::warn!( - "Invalid CODEX_REASONING_EFFORT '{}' for model '{}', using 'high'", - reasoning_effort, - model.model_name - ); - "high".to_string() - }; - - // Get skip_git_check from config, default to false - let skip_git_check = config - .get_codex_skip_git_check() - .map(|s| s.to_lowercase() == "true") - .unwrap_or(false); - - Ok(Self { - command: resolved_command, - model, - name: CODEX_PROVIDER_NAME.to_string(), - reasoning_effort, - skip_git_check, - }) - } - fn supports_reasoning_effort(model_name: &str, reasoning_effort: &str) -> bool { if !CODEX_REASONING_LEVELS.contains(&reasoning_effort) { return false; @@ -600,7 +561,44 @@ impl ProviderDef for CodexProvider { } fn from_env(model: ModelConfig) -> BoxFuture<'static, Result> { - Box::pin(Self::from_env(model)) + Box::pin(async move { + let config = Config::global(); + let command: String = config.get_codex_command().unwrap_or_default().into(); + let resolved_command = SearchPaths::builder().with_npm().resolve(command)?; + + // Get reasoning effort from config, default to "high" + let reasoning_effort = config + .get_codex_reasoning_effort() + .map(String::from) + .unwrap_or_else(|_| "high".to_string()); + + // Validate reasoning effort + let reasoning_effort = + if Self::supports_reasoning_effort(&model.model_name, &reasoning_effort) { + reasoning_effort + } else { + tracing::warn!( + "Invalid CODEX_REASONING_EFFORT '{}' for model '{}', using 'high'", + reasoning_effort, + model.model_name + ); + "high".to_string() + }; + + // Get skip_git_check from config, default to false + let skip_git_check = config + .get_codex_skip_git_check() + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false); + + Ok(Self { + command: resolved_command, + model, + name: CODEX_PROVIDER_NAME.to_string(), + reasoning_effort, + skip_git_check, + }) + }) } } diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 9be54de9d0dc..a077a71e53b3 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -7,6 +7,8 @@ use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL; use goose::providers::azure::AZURE_DEFAULT_MODEL; use goose::providers::base::Provider; use goose::providers::bedrock::BEDROCK_DEFAULT_MODEL; +use goose::providers::claude_code::CLAUDE_CODE_DEFAULT_MODEL; +use goose::providers::codex::CODEX_DEFAULT_MODEL; use goose::providers::create_with_named_model; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::providers::errors::ProviderError; @@ -89,6 +91,7 @@ struct ProviderTester { provider: Arc, name: String, extension_manager: Arc, + is_cli_provider: bool, } impl ProviderTester { @@ -96,18 +99,20 @@ impl ProviderTester { provider: Arc, name: String, extension_manager: Arc, + is_cli_provider: bool, ) -> Self { Self { provider, name, extension_manager, + is_cli_provider, } } - async fn tool_roundtrip(&self, prompt: &str) -> Result { + async fn tool_roundtrip(&self, prompt: &str, session_id: &str) -> Result { let tools = self .extension_manager - .get_prefixed_tools(TEST_SESSION_ID, None) + .get_prefixed_tools(session_id, None) .await .expect("get_prefixed_tools failed"); @@ -120,20 +125,22 @@ impl ProviderTester { let message = Message::user().with_text(prompt); let (response1, _) = self .provider - .complete( - TEST_SESSION_ID, - &system, - std::slice::from_ref(&message), - &tools, - ) + .complete(session_id, &system, std::slice::from_ref(&message), &tools) .await?; + // Agentic CLI providers (claude-code, codex) call tools internally and + // return the final text result directly — no tool_request in the response. let tool_req = response1 .content .iter() .filter_map(|c| c.as_tool_request()) - .next_back() - .expect("Expected provider to return a tool request"); + .next_back(); + + let tool_req = match tool_req { + Some(req) => req, + None => return Ok(response1), + }; + let params = tool_req .tool_call .as_ref() @@ -141,7 +148,7 @@ impl ProviderTester { .clone(); let result = self .extension_manager - .dispatch_tool_call(TEST_SESSION_ID, params, None, CancellationToken::new()) + .dispatch_tool_call(session_id, params, None, CancellationToken::new()) .await .expect("dispatch failed") .result @@ -152,7 +159,7 @@ impl ProviderTester { let (response2, _) = self .provider .complete( - TEST_SESSION_ID, + session_id, &system, &[message, response1, tool_response], &tools, @@ -161,17 +168,12 @@ impl ProviderTester { Ok(response2) } - async fn test_basic_response(&self) -> Result<()> { + async fn test_basic_response(&self, session_id: &str) -> Result<()> { let message = Message::user().with_text("Just say hello!"); let (response, _) = self .provider - .complete( - TEST_SESSION_ID, - "You are a helpful assistant.", - &[message], - &[], - ) + .complete(session_id, "You are a helpful assistant.", &[message], &[]) .await?; assert_eq!( @@ -185,21 +187,33 @@ impl ProviderTester { "Expected text response" ); + println!( + "=== {}::basic_response === {}", + self.name, + response.as_concat_text() + ); Ok(()) } - async fn test_tool_usage(&self) -> Result<()> { + async fn test_tool_usage(&self, session_id: &str) -> Result<()> { let response = self - .tool_roundtrip("Use the get_code tool and output only its result.") + .tool_roundtrip( + "Use the get_code tool and output only its result.", + session_id, + ) .await?; + let text = response.as_concat_text(); assert!( - response.as_concat_text().contains(FAKE_CODE), - "Expected lookup code in final response" + text.contains(FAKE_CODE), + "Expected lookup code '{}' in final response, got: {}", + FAKE_CODE, + text ); + println!("=== {}::tool_usage === {}", self.name, text); Ok(()) } - async fn test_context_length_exceeded_error(&self) -> Result<()> { + async fn test_context_length_exceeded_error(&self, session_id: &str) -> Result<()> { let large_message_content = if self.name.to_lowercase() == "google" { "hello ".repeat(1_300_000) } else { @@ -220,19 +234,17 @@ impl ProviderTester { let result = self .provider - .complete( - TEST_SESSION_ID, - "You are a helpful assistant.", - &messages, - &[], - ) + .complete(session_id, "You are a helpful assistant.", &messages, &[]) .await; println!("=== {}::context_length_exceeded_error ===", self.name); dbg!(&result); println!("==================="); - if self.name.to_lowercase() == "ollama" || self.name.to_lowercase() == "openrouter" { + let name_lower = self.name.to_lowercase(); + if name_lower == "ollama" || name_lower == "openrouter" { + // These providers handle context overflow internally: ollama and + // openrouter truncate or have large windows. assert!( result.is_ok(), "Expected to succeed because of default truncation or large context window" @@ -252,9 +264,12 @@ impl ProviderTester { Ok(()) } - async fn test_image_content_support(&self) -> Result<()> { + async fn test_image_content_support(&self, session_id: &str) -> Result<()> { let response = self - .tool_roundtrip("Use the get_image tool and describe what you see in its result.") + .tool_roundtrip( + "Use the get_image tool and describe what you see in its result.", + session_id, + ) .await?; let text = response.as_concat_text().to_lowercase(); assert!( @@ -262,6 +277,7 @@ impl ProviderTester { "Expected response to describe the test image, got: {}", text ); + println!("=== {}::image_content === {}", self.name, text); Ok(()) } @@ -274,24 +290,40 @@ impl ProviderTester { assert!(!models.is_empty(), "Expected non-empty model list"); let model_name = &self.provider.get_model_config().model_name; - // Some providers (e.g. Ollama) return names with tags like "qwen3:latest" - // while the configured model name may be just "qwen3". + // Model names may not match exactly: Ollama adds tags like "qwen3:latest", + // and CLI providers like claude-code use aliases (e.g. "sonnet") that are + // substrings of full model names (e.g. "claude-sonnet-4-5-20250929"). assert!( models .iter() - .any(|m| m == model_name || m.starts_with(&format!("{}:", model_name))), + .any(|m| m == model_name || m.contains(model_name) || model_name.contains(m)), "Expected model '{}' in supported models", model_name ); Ok(()) } + fn session_id_for_test(&self, test_name: &str) -> String { + if self.is_cli_provider { + format!("test_{}", test_name) + } else { + TEST_SESSION_ID.to_string() + } + } + async fn run_test_suite(&self) -> Result<()> { self.test_model_listing().await?; - self.test_basic_response().await?; - self.test_tool_usage().await?; - self.test_context_length_exceeded_error().await?; - self.test_image_content_support().await?; + self.test_basic_response(&self.session_id_for_test("basic_response")) + .await?; + // TODO: remove skip in https://github.com/block/goose/pull/6972 + if !self.is_cli_provider { + self.test_tool_usage(&self.session_id_for_test("tool_usage")) + .await?; + self.test_image_content_support(&self.session_id_for_test("image_content")) + .await?; + } + self.test_context_length_exceeded_error(&self.session_id_for_test("context_length")) + .await?; Ok(()) } } @@ -307,6 +339,8 @@ async fn test_provider( model_name: &str, required_vars: &[&str], env_modifications: Option>>, + // CLI providers cannot propagate the agent-session-id header to MCP servers. + is_cli_provider: bool, ) -> Result<()> { TEST_REPORT.record_fail(name); @@ -350,10 +384,19 @@ async fn test_provider( original_env }; - let expected_session_id = ExpectedSessionId::default(); let provider_name = name.to_lowercase(); - let mcp = McpFixture::new(Some(expected_session_id.clone())).await; - expected_session_id.set(TEST_SESSION_ID); + let expected_session_id = if is_cli_provider { + None + } else { + Some(ExpectedSessionId::default()) + }; + let mcp = McpFixture::new(expected_session_id.clone()).await; + if let Some(ref id) = expected_session_id { + id.set(TEST_SESSION_ID); + } + + let mcp_extension = + ExtensionConfig::streamable_http("mcp-fixture", &mcp.url, "MCP fixture", 30_u64); let provider = match create_with_named_model(&provider_name, model_name).await { Ok(p) => p, @@ -383,16 +426,16 @@ async fn test_provider( let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); let extension_manager = Arc::new(ExtensionManager::new(shared_provider, session_manager)); extension_manager - .add_extension( - ExtensionConfig::streamable_http("mcp-fixture", &mcp.url, "MCP fixture", 30_u64), - None, - None, - None, - ) + .add_extension(mcp_extension, None, None, None) .await .expect("failed to add extension"); - let tester = ProviderTester::new(provider, name.to_string(), extension_manager); + let tester = ProviderTester::new( + provider, + name.to_string(), + extension_manager, + is_cli_provider, + ); let _mcp = mcp; let result = tester.run_test_suite().await; @@ -411,7 +454,14 @@ async fn test_provider( #[tokio::test] async fn test_openai_provider() -> Result<()> { - test_provider("openai", OPEN_AI_DEFAULT_MODEL, &["OPENAI_API_KEY"], None).await + test_provider( + "openai", + OPEN_AI_DEFAULT_MODEL, + &["OPENAI_API_KEY"], + None, + false, + ) + .await } #[tokio::test] @@ -425,6 +475,7 @@ async fn test_azure_provider() -> Result<()> { "AZURE_OPENAI_DEPLOYMENT_NAME", ], None, + false, ) .await } @@ -436,6 +487,7 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> { BEDROCK_DEFAULT_MODEL, &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], None, + false, ) .await } @@ -450,6 +502,7 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { BEDROCK_DEFAULT_MODEL, &["AWS_PROFILE"], Some(env_mods), + false, ) .await } @@ -468,6 +521,7 @@ async fn test_bedrock_provider_bearer_token() -> Result<()> { BEDROCK_DEFAULT_MODEL, &["AWS_BEARER_TOKEN_BEDROCK", "AWS_REGION"], Some(env_mods), + false, ) .await } @@ -479,6 +533,7 @@ async fn test_databricks_provider() -> Result<()> { DATABRICKS_DEFAULT_MODEL, &["DATABRICKS_HOST", "DATABRICKS_TOKEN"], None, + false, ) .await } @@ -486,7 +541,7 @@ async fn test_databricks_provider() -> Result<()> { #[tokio::test] async fn test_ollama_provider() -> Result<()> { // qwen3-vl supports text, tools, and vision (needed for image test) - test_provider("Ollama", "qwen3-vl", &["OLLAMA_HOST"], None).await + test_provider("Ollama", "qwen3-vl", &["OLLAMA_HOST"], None, false).await } #[tokio::test] @@ -496,6 +551,7 @@ async fn test_anthropic_provider() -> Result<()> { ANTHROPIC_DEFAULT_MODEL, &["ANTHROPIC_API_KEY"], None, + false, ) .await } @@ -507,13 +563,21 @@ async fn test_openrouter_provider() -> Result<()> { OPEN_AI_DEFAULT_MODEL, &["OPENROUTER_API_KEY"], None, + false, ) .await } #[tokio::test] async fn test_google_provider() -> Result<()> { - test_provider("Google", GOOGLE_DEFAULT_MODEL, &["GOOGLE_API_KEY"], None).await + test_provider( + "Google", + GOOGLE_DEFAULT_MODEL, + &["GOOGLE_API_KEY"], + None, + false, + ) + .await } #[tokio::test] @@ -523,6 +587,7 @@ async fn test_snowflake_provider() -> Result<()> { SNOWFLAKE_DEFAULT_MODEL, &["SNOWFLAKE_HOST", "SNOWFLAKE_TOKEN"], None, + false, ) .await } @@ -534,6 +599,7 @@ async fn test_sagemaker_tgi_provider() -> Result<()> { SAGEMAKER_TGI_DEFAULT_MODEL, &["SAGEMAKER_ENDPOINT_NAME"], None, + false, ) .await } @@ -551,12 +617,32 @@ async fn test_litellm_provider() -> Result<()> { ("LITELLM_API_KEY", Some("".to_string())), ]); - test_provider("LiteLLM", LITELLM_DEFAULT_MODEL, &[], Some(env_mods)).await + test_provider("LiteLLM", LITELLM_DEFAULT_MODEL, &[], Some(env_mods), false).await } #[tokio::test] async fn test_xai_provider() -> Result<()> { - test_provider("Xai", XAI_DEFAULT_MODEL, &["XAI_API_KEY"], None).await + test_provider("Xai", XAI_DEFAULT_MODEL, &["XAI_API_KEY"], None, false).await +} + +#[tokio::test] +async fn test_claude_code_provider() -> Result<()> { + if which::which("claude").is_err() { + println!("'claude' CLI not found, skipping test"); + TEST_REPORT.record_skip("claude-code"); + return Ok(()); + } + test_provider("claude-code", CLAUDE_CODE_DEFAULT_MODEL, &[], None, true).await +} + +#[tokio::test] +async fn test_codex_provider() -> Result<()> { + if which::which("codex").is_err() { + println!("'codex' CLI not found, skipping test"); + TEST_REPORT.record_skip("codex"); + return Ok(()); + } + test_provider("codex", CODEX_DEFAULT_MODEL, &[], None, true).await } #[ctor::dtor]