Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::Duration;

Expand Down Expand Up @@ -123,6 +123,79 @@ impl OpenAiProvider {
}
}

/// Update the request when using anthropic model.
/// For anthropic model, we can enable prompt caching to save cost. Since openrouter is the OpenAI compatible
/// endpoint, we need to modify the open ai request to have anthropic cache control field.
fn update_request_for_anthropic(original_payload: &Value) -> Value {
let mut payload = original_payload.clone();

if let Some(messages_spec) = payload
.as_object_mut()
.and_then(|obj| obj.get_mut("messages"))
.and_then(|messages| messages.as_array_mut())
{
// Add "cache_control" to the last and second-to-last "user" messages.
// During each turn, we mark the final message with cache_control so the conversation can be
// incrementally cached. The second-to-last user message is also marked for caching with the
// cache_control parameter, so that this checkpoint can read from the previous cache.
let mut user_count = 0;
for message in messages_spec.iter_mut().rev() {
if message.get("role") == Some(&json!("user")) {
if let Some(content) = message.get_mut("content") {
if let Some(content_str) = content.as_str() {
*content = json!([{
"type": "text",
"text": content_str,
"cache_control": { "type": "ephemeral" }
}]);
}
}
user_count += 1;
if user_count >= 2 {
break;
}
}
}

// Update the system message to have cache_control field.
if let Some(system_message) = messages_spec
.iter_mut()
.find(|msg| msg.get("role") == Some(&json!("system")))
{
if let Some(content) = system_message.get_mut("content") {
if let Some(content_str) = content.as_str() {
*system_message = json!({
"role": "system",
"content": [{
"type": "text",
"text": content_str,
"cache_control": { "type": "ephemeral" }
}]
});
}
}
}
}

if let Some(tools_spec) = payload
.as_object_mut()
.and_then(|obj| obj.get_mut("tools"))
.and_then(|tools| tools.as_array_mut())
{
// Add "cache_control" to the last tool spec, if any. This means that all tool definitions,
// will be cached as a single prefix.
if let Some(last_tool) = tools_spec.last_mut() {
if let Some(function) = last_tool.get_mut("function") {
function
.as_object_mut()
.unwrap()
.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
}
}
}
payload
}

#[async_trait]
impl Provider for OpenAiProvider {
fn metadata() -> ProviderMetadata {
Expand Down Expand Up @@ -167,7 +240,13 @@ impl Provider for OpenAiProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
let mut payload =
create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;

// Add cache_control for claude models (LiteLLM and other OpenAI-compatible services)
if self.model.model_name.to_lowercase().contains("claude") {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we add a method to the Provider trait for this instead?

payload = update_request_for_anthropic(&payload);
}

// Make request
let response = self.post(payload.clone()).await?;
Expand Down
Loading