diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 9884d147bffc..f3381d27beab 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,7 +1,7 @@ use anyhow::Result; use async_trait::async_trait; use reqwest::Client; -use serde_json::Value; +use serde_json::{json, Value}; use std::collections::HashMap; use std::time::Duration; @@ -123,6 +123,79 @@ impl OpenAiProvider { } } +/// Update the request when using anthropic model. +/// For anthropic model, we can enable prompt caching to save cost. Since openrouter is the OpenAI compatible +/// endpoint, we need to modify the open ai request to have anthropic cache control field. +fn update_request_for_anthropic(original_payload: &Value) -> Value { + let mut payload = original_payload.clone(); + + if let Some(messages_spec) = payload + .as_object_mut() + .and_then(|obj| obj.get_mut("messages")) + .and_then(|messages| messages.as_array_mut()) + { + // Add "cache_control" to the last and second-to-last "user" messages. + // During each turn, we mark the final message with cache_control so the conversation can be + // incrementally cached. The second-to-last user message is also marked for caching with the + // cache_control parameter, so that this checkpoint can read from the previous cache. + let mut user_count = 0; + for message in messages_spec.iter_mut().rev() { + if message.get("role") == Some(&json!("user")) { + if let Some(content) = message.get_mut("content") { + if let Some(content_str) = content.as_str() { + *content = json!([{ + "type": "text", + "text": content_str, + "cache_control": { "type": "ephemeral" } + }]); + } + } + user_count += 1; + if user_count >= 2 { + break; + } + } + } + + // Update the system message to have cache_control field. + if let Some(system_message) = messages_spec + .iter_mut() + .find(|msg| msg.get("role") == Some(&json!("system"))) + { + if let Some(content) = system_message.get_mut("content") { + if let Some(content_str) = content.as_str() { + *system_message = json!({ + "role": "system", + "content": [{ + "type": "text", + "text": content_str, + "cache_control": { "type": "ephemeral" } + }] + }); + } + } + } + } + + if let Some(tools_spec) = payload + .as_object_mut() + .and_then(|obj| obj.get_mut("tools")) + .and_then(|tools| tools.as_array_mut()) + { + // Add "cache_control" to the last tool spec, if any. This means that all tool definitions, + // will be cached as a single prefix. + if let Some(last_tool) = tools_spec.last_mut() { + if let Some(function) = last_tool.get_mut("function") { + function + .as_object_mut() + .unwrap() + .insert("cache_control".to_string(), json!({ "type": "ephemeral" })); + } + } + } + payload +} + #[async_trait] impl Provider for OpenAiProvider { fn metadata() -> ProviderMetadata { @@ -167,7 +240,13 @@ impl Provider for OpenAiProvider { messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + let mut payload = + create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + + // Add cache_control for claude models (LiteLLM and other OpenAI-compatible services) + if self.model.model_name.to_lowercase().contains("claude") { + payload = update_request_for_anthropic(&payload); + } // Make request let response = self.post(payload.clone()).await?;