Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::errors::ProviderError;
use super::formats::anthropic::{
create_request, get_usage, response_to_message, response_to_streaming_message,
};
use super::utils::{get_model, map_http_error_to_provider_error};
use super::utils::{get_model, handle_status_openai_compat, map_http_error_to_provider_error};
use crate::config::declarative_providers::DeclarativeProviderConfig;
use crate::conversation::message::Message;
use crate::model::ModelConfig;
Expand Down Expand Up @@ -273,17 +273,12 @@ impl Provider for AnthropicProvider {
request = request.header(key, value)?;
}

let response = request.response_post(&payload).await.inspect_err(|e| {
let resp = request.response_post(&payload).await.inspect_err(|e| {
let _ = log.error(e);
})?;
let response = handle_status_openai_compat(resp).await.inspect_err(|e| {
let _ = log.error(e);
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
let error_json = serde_json::from_str::<Value>(&error_text).ok();
let error = map_http_error_to_provider_error(status, error_json);
let _ = log.error(&error);
return Err(error);
}

let stream = response.bytes_stream().map_err(io::Error::other);

Expand Down
103 changes: 0 additions & 103 deletions crates/goose/src/providers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,106 +112,3 @@ impl GoogleErrorCode {
}
}
}

#[derive(serde::Deserialize, Debug)]
pub struct OpenAIError {
#[serde(deserialize_with = "code_as_string")]
pub code: Option<String>,
pub message: Option<String>,
#[serde(rename = "type")]
pub error_type: Option<String>,
}

fn code_as_string<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, Visitor};
use std::fmt;

struct CodeVisitor;

impl<'de> Visitor<'de> for CodeVisitor {
type Value = Option<String>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string, a number, null, or none for the code field")
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Some(value.to_string()))
}

fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Some(value.to_string()))
}

fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}

fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}

fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(CodeVisitor)
}
}

deserializer.deserialize_option(CodeVisitor)
}

impl OpenAIError {
pub fn is_context_length_exceeded(&self) -> bool {
if let Some(code) = &self.code {
code == "context_length_exceeded" || code == "string_above_max_length"
} else {
false
}
}
}

impl std::fmt::Display for OpenAIError {
/// Format the error for display.
/// E.g. {"message": "Invalid API key", "code": "invalid_api_key", "type": "client_error"}
/// would be formatted as "Invalid API key (code: invalid_api_key, type: client_error)"
/// and {"message": "Foo"} as just "Foo", etc.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(message) = &self.message {
write!(f, "{}", message)?;
}
let mut in_parenthesis = false;
if let Some(code) = &self.code {
write!(f, " (code: {}", code)?;
in_parenthesis = true;
}
if let Some(typ) = &self.error_type {
if in_parenthesis {
write!(f, ", type: {}", typ)?;
} else {
write!(f, " (type: {}", typ)?;
in_parenthesis = true;
}
}
if in_parenthesis {
write!(f, ")")?;
}
Ok(())
}
}
8 changes: 1 addition & 7 deletions crates/goose/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,12 @@ impl Provider for OllamaProvider {
.api_client
.response_post("v1/chat/completions", &payload)
.await?;
let status = resp.status();
if !status.is_success() {
return Err(super::utils::map_http_error_to_provider_error(status, None));
}
Ok(resp)
handle_status_openai_compat(resp).await
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
let response = handle_status_openai_compat(response).await?;

let stream = response.bytes_stream().map_err(io::Error::other);

Ok(Box::pin(try_stream! {
Expand Down
10 changes: 1 addition & 9 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,20 +333,12 @@ impl Provider for OpenAiProvider {
.api_client
.response_post(&self.base_path, &payload)
.await?;
let status = resp.status();
if !status.is_success() {
return Err(super::utils::map_http_error_to_provider_error(
status, None, // We'll let handle_status_openai_compat parse the error
));
}
Ok(resp)
handle_status_openai_compat(resp).await
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
let response = handle_status_openai_compat(response).await?;

let stream = response.bytes_stream().map_err(io::Error::other);

Ok(Box::pin(try_stream! {
Expand Down
6 changes: 3 additions & 3 deletions crates/goose/src/providers/tetrate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,18 @@ impl Provider for TetrateProvider {
&super::utils::ImageFormat::OpenAi,
)?;

// Enable streaming
payload["stream"] = json!(true);
payload["stream_options"] = json!({
"include_usage": true,
});

let response = self
let resp = self
.api_client
.response_post("v1/chat/completions", &payload)
.await?;

let response = handle_status_openai_compat(response).await?;
let response = handle_status_openai_compat(resp).await?;

let stream = response.bytes_stream().map_err(io::Error::other);
let mut log = RequestLog::start(&self.model, &payload)?;

Expand Down
Loading
Loading