Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 96 additions & 3 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
use anyhow::Result;
use async_stream::try_stream;
use async_trait::async_trait;
use axum::http::HeaderMap;
use futures::TryStreamExt;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::io;
use std::time::Duration;
use tokio::pin;

use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use tokio_util::io::StreamReader;

use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::anthropic::{create_request, get_usage, response_to_message};
use super::formats::anthropic::{
create_request, get_usage, response_to_message, response_to_streaming_message,
};
use super::utils::{emit_debug_trace, get_model};
use crate::message::Message;
use crate::model::ModelConfig;
Expand Down Expand Up @@ -195,10 +203,17 @@ impl Provider for AnthropicProvider {
// Parse response
let message = response_to_message(response.clone())?;
let usage = get_usage(&response)?;
tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
usage.input_tokens, usage.output_tokens, usage.total_tokens);

let model = get_model(&response);
emit_debug_trace(&self.model, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
let provider_usage = ProviderUsage::new(model, usage);
tracing::debug!(
"🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
provider_usage
);
Ok((message, provider_usage))
}

/// Fetch supported models from Anthropic; returns Err on failure, Ok(None) if not present
Expand Down Expand Up @@ -232,4 +247,82 @@ impl Provider for AnthropicProvider {
models.sort();
Ok(Some(models))
}

async fn stream(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let mut payload = create_request(&self.model, system, messages, tools)?;

// Add stream parameter
payload
.as_object_mut()
.unwrap()
.insert("stream".to_string(), Value::Bool(true));

let mut headers = reqwest::header::HeaderMap::new();
headers.insert("x-api-key", self.api_key.parse().unwrap());
headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap());

let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok();
if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled {
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta
headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap());
}

if self.model.model_name.starts_with("claude-3-7-sonnet-") {
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use
headers.insert(
"anthropic-beta",
"token-efficient-tools-2025-02-19".parse().unwrap(),
);
}

let base_url = url::Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("v1/messages").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;

let response = self
.client
.post(url)
.headers(headers)
.json(&payload)
.send()
.await?;

if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(ProviderError::RequestFailed(format!(
"Streaming request failed with status: {}. Error: {}",
status, error_text
)));
}

// Map reqwest error to io::Error
let stream = response.bytes_stream().map_err(io::Error::other);

let model_config = self.model.clone();
// Wrap in a line decoder and yield lines inside the stream
Ok(Box::pin(try_stream! {
let stream_reader = StreamReader::new(stream);
let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from);

let message_stream = response_to_streaming_message(framed);
pin!(message_stream);
while let Some(message) = futures::StreamExt::next(&mut message_stream).await {
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
yield (message, usage);
}
}))
}

fn supports_streaming(&self) -> bool {
true
}
}
Loading
Loading