8 changes: 4 additions & 4 deletions crates/goose/src/providers/anthropic.rs
@@ -73,7 +73,7 @@ impl AnthropicProvider {
})
}

async fn post(&self, headers: HeaderMap, payload: Value) -> Result<Value, ProviderError> {
async fn post(&self, headers: HeaderMap, payload: &Value) -> Result<Value, ProviderError> {
let base_url = url::Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("v1/messages").map_err(|e| {
@@ -84,7 +84,7 @@ impl AnthropicProvider {
.client
.post(url)
.headers(headers)
.json(&payload)
.json(payload)
.send()
.await?;

@@ -198,10 +198,10 @@ impl Provider for AnthropicProvider {
}

// Make request
let response = self.post(headers, payload.clone()).await?;
let response = self.post(headers, &payload).await?;

// Parse response
let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = get_usage(&response)?;
tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
usage.input_tokens, usage.output_tokens, usage.total_tokens);
8 changes: 4 additions & 4 deletions crates/goose/src/providers/azure.rs
@@ -87,7 +87,7 @@ impl AzureProvider {
})
}

async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
let mut base_url = url::Url::parse(&self.endpoint)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;

@@ -143,7 +143,7 @@ impl AzureProvider {
}
}

let response_result = request_builder.json(&payload).send().await;
let response_result = request_builder.json(payload).send().await;

match response_result {
Ok(response) => match handle_response_openai_compat(response).await {
@@ -249,9 +249,9 @@ impl Provider for AzureProvider {
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
let response = self.post(payload.clone()).await?;
let response = self.post(&payload).await?;

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
tracing::debug!("Failed to get usage data");
Usage::default()
10 changes: 5 additions & 5 deletions crates/goose/src/providers/databricks.rs
@@ -273,7 +273,7 @@ impl DatabricksProvider {
}
}

async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
// Check if this is an embedding request by looking at the payload structure
let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
let path = if is_embedding {
@@ -284,7 +284,7 @@
format!("serving-endpoints/{}/invocations", self.model.model_name)
};

match self.post_with_retry(path.as_str(), &payload).await {
match self.post_with_retry(path.as_str(), payload).await {
Ok(res) => res.json().await.map_err(|_| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
}),
@@ -451,10 +451,10 @@ impl Provider for DatabricksProvider {
.expect("payload should have model key")
.remove("model");

let response = self.post(payload.clone()).await?;
let response = self.post(&payload).await?;

// Parse response
let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
tracing::debug!("Failed to get usage data");
Usage::default()
@@ -619,7 +619,7 @@ impl EmbeddingCapable for DatabricksProvider {
"input": texts,
});

let response = self.post(request).await?;
let response = self.post(&request).await?;

let embeddings = response["data"]
.as_array()
8 changes: 4 additions & 4 deletions crates/goose/src/providers/formats/anthropic.rs
@@ -207,7 +207,7 @@ pub fn format_system(system: &str) -> Value {
}

/// Convert Anthropic's API response to internal Message format
pub fn response_to_message(response: Value) -> Result<Message> {
pub fn response_to_message(response: &Value) -> Result<Message> {
let content_blocks = response
.get(CONTENT_FIELD)
.and_then(|c| c.as_array())
@@ -690,7 +690,7 @@ mod tests {
}
});

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = get_usage(&response)?;

if let MessageContent::Text(text) = &message.content[0] {
@@ -731,7 +731,7 @@
}
});

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = get_usage(&response)?;

if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
@@ -781,7 +781,7 @@
}
});

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = get_usage(&response)?;

assert_eq!(message.content.len(), 3);
22 changes: 11 additions & 11 deletions crates/goose/src/providers/formats/databricks.rs
@@ -268,8 +268,8 @@ pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
}

/// Convert Databricks' API response to internal Message format
pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
let original = response["choices"][0]["message"].clone();
pub fn response_to_message(response: &Value) -> anyhow::Result<Message> {
let original = &response["choices"][0]["message"];
let mut content = Vec::new();

// Handle array-based content
@@ -737,7 +737,7 @@ mod tests {

// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
request.id.clone()
&request.id
} else {
panic!("should be tool request");
};
@@ -770,7 +770,7 @@ mod tests {

// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
request.id.clone()
&request.id
} else {
panic!("should be tool request");
};
@@ -891,7 +891,7 @@ mod tests {
}
});

let message = response_to_message(response)?;
let message = response_to_message(&response)?;
assert_eq!(message.content.len(), 1);
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "Hello from John Cena!");
@@ -906,7 +906,7 @@
#[test]
fn test_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
let message = response_to_message(response)?;
let message = response_to_message(&response)?;

assert_eq!(message.content.len(), 1);
if let MessageContent::ToolRequest(request) = &message.content[0] {
@@ -926,7 +926,7 @@
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
json!("invalid fn");

let message = response_to_message(response)?;
let message = response_to_message(&response)?;

if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
@@ -948,7 +948,7 @@
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
json!("invalid json {");

let message = response_to_message(response)?;
let message = response_to_message(&response)?;

if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
@@ -970,7 +970,7 @@
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
serde_json::Value::String("".to_string());

let message = response_to_message(response)?;
let message = response_to_message(&response)?;

if let MessageContent::ToolRequest(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
@@ -1107,7 +1107,7 @@ mod tests {
}]
});

let message = response_to_message(response)?;
let message = response_to_message(&response)?;
assert_eq!(message.content.len(), 2);

if let MessageContent::Thinking(thinking) = &message.content[0] {
@@ -1154,7 +1154,7 @@ mod tests {
}]
});

let message = response_to_message(response)?;
let message = response_to_message(&response)?;
assert_eq!(message.content.len(), 2);

if let MessageContent::RedactedThinking(redacted) = &message.content[0] {
2 changes: 1 addition & 1 deletion crates/goose/src/providers/formats/gcpvertexai.rs
@@ -332,7 +332,7 @@ pub fn create_request(
/// * `Result<Message>` - Converted message
pub fn response_to_message(response: Value, request_context: RequestContext) -> Result<Message> {
match request_context.provider() {
ModelProvider::Anthropic => anthropic::response_to_message(response),
ModelProvider::Anthropic => anthropic::response_to_message(&response),
ModelProvider::Google => google::response_to_message(response),
}
}
14 changes: 7 additions & 7 deletions crates/goose/src/providers/formats/openai.rs
@@ -268,8 +268,8 @@ pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
}

/// Convert OpenAI's API response to internal Message format
pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
let original = response["choices"][0]["message"].clone();
pub fn response_to_message(response: &Value) -> anyhow::Result<Message> {
let original = &response["choices"][0]["message"];
let mut content = Vec::new();

if let Some(text) = original.get("content") {
@@ -910,7 +910,7 @@ mod tests {
}
});

let message = response_to_message(response)?;
let message = response_to_message(&response)?;
assert_eq!(message.content.len(), 1);
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "Hello from John Cena!");
@@ -925,7 +925,7 @@
#[test]
fn test_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
let message = response_to_message(response)?;
let message = response_to_message(&response)?;

assert_eq!(message.content.len(), 1);
if let MessageContent::ToolRequest(request) = &message.content[0] {
@@ -945,7 +945,7 @@
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
json!("invalid fn");

let message = response_to_message(response)?;
let message = response_to_message(&response)?;

if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
@@ -967,7 +967,7 @@
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
json!("invalid json {");

let message = response_to_message(response)?;
let message = response_to_message(&response)?;

if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
@@ -989,7 +989,7 @@
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
serde_json::Value::String("".to_string());

let message = response_to_message(response)?;
let message = response_to_message(&response)?;

if let MessageContent::ToolRequest(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
8 changes: 4 additions & 4 deletions crates/goose/src/providers/formats/snowflake.rs
@@ -198,7 +198,7 @@ pub fn parse_streaming_response(sse_data: &str) -> Result<Message> {
}

/// Convert Snowflake's API response to internal Message format
pub fn response_to_message(response: Value) -> Result<Message> {
pub fn response_to_message(response: &Value) -> Result<Message> {
let mut message = Message::assistant();

let content_list = response.get("content_list").and_then(|cl| cl.as_array());
@@ -380,7 +380,7 @@ mod tests {
}
});

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = get_usage(&response)?;

if let MessageContent::Text(text) = &message.content[0] {
@@ -417,7 +417,7 @@ mod tests {
}
});

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;
let usage = get_usage(&response)?;

if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
@@ -625,7 +625,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","
}
});

let message = response_to_message(response.clone())?;
let message = response_to_message(&response)?;

// Should have both text and tool request content
assert_eq!(message.content.len(), 2);
18 changes: 11 additions & 7 deletions crates/goose/src/providers/gcpvertexai.rs
@@ -281,14 +281,14 @@ impl GcpVertexAIProvider {
) -> Result<Url, GcpVertexAIError> {
// Create host URL for the specified location
let host_url = if self.location == location {
self.host.clone()
&self.host
} else {
// Only allocate a new string if location differs
self.host.replace(&self.location, location)
&self.host.replace(&self.location, location)
};

let base_url =
Url::parse(&host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
Url::parse(host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;

// Determine endpoint based on provider type
let endpoint = match provider {
@@ -470,10 +470,14 @@ impl GcpVertexAIProvider {
/// # Arguments
/// * `payload` - The request payload to send
/// * `context` - Request context containing model information
async fn post(&self, payload: Value, context: &RequestContext) -> Result<Value, ProviderError> {
async fn post(
&self,
payload: &Value,
context: &RequestContext,
) -> Result<Value, ProviderError> {
// Try with user-specified location first
let result = self
.post_with_location(&payload, context, &self.location)
.post_with_location(payload, context, &self.location)
.await;

// If location is already the known location for the model or request succeeded, return result
Expand All @@ -492,7 +496,7 @@ impl GcpVertexAIProvider {
"Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
);

self.post_with_location(&payload, context, &known_location)
self.post_with_location(payload, context, &known_location)
.await
}
// For any other error, return the original result
@@ -609,7 +613,7 @@ impl Provider for GcpVertexAIProvider {
let (request, context) = create_request(&self.model, system, messages, tools)?;

// Send request and process response
let response = self.post(request.clone(), &context).await?;
let response = self.post(&request, &context).await?;
let usage = get_usage(&response, &context)?;

emit_debug_trace(&self.model, &request, &response, &usage);