From 7768f3698343df8189d966c71f1f58216c289291 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 20 Dec 2024 13:37:04 +1100 Subject: [PATCH 1/6] encapsulate the concat_tool_response_content by creating delegating functions --- crates/goose/src/providers/databricks.rs | 7 ++++- crates/goose/src/providers/groq.rs | 11 +++++--- crates/goose/src/providers/ollama.rs | 3 +-- crates/goose/src/providers/openai.rs | 3 +-- crates/goose/src/providers/openai_utils.rs | 30 ++++++++++++++++++++++ 5 files changed, 46 insertions(+), 8 deletions(-) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 6a2670392ea1..98950dbee963 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -81,7 +81,12 @@ impl Provider for DatabricksProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { // Prepare messages and tools - let messages_spec = messages_to_openai_spec(messages, &self.config.image_format, false); + let concat_tool_response_contents = false; + let messages_spec = messages_to_openai_spec( + messages, + &self.config.image_format, + concat_tool_response_contents, + ); let tools_spec = if !tools.is_empty() { tools_to_openai_spec(tools)? } else { diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index b0d1be1dc2a7..88c803938d59 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -2,7 +2,8 @@ use crate::message::Message; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; use crate::providers::openai_utils::{ - create_openai_request_payload, get_openai_usage, openai_response_to_message, + create_openai_request_payload_with_concat_response_content, get_openai_usage, + openai_response_to_message, }; use crate::providers::utils::{get_model, handle_response}; use async_trait::async_trait; @@ -61,8 +62,12 @@ impl Provider for GroqProvider { messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage)> { - let payload = - create_openai_request_payload(&self.config.model, system, messages, tools, true)?; + let payload = create_openai_request_payload_with_concat_response_content( + &self.config.model, + system, + messages, + tools, + )?; let response = self.post(payload).await?; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 540f9291f862..84cc0b4eb2b1 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -57,8 +57,7 @@ impl Provider for OllamaProvider { messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { - let payload = - create_openai_request_payload(&self.config.model, system, messages, tools, false)?; + let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?; let response = self.post(payload).await?; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index f735b2209c64..54884b040dec 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -69,8 +69,7 @@ impl Provider for OpenAiProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { // Not checking for o1 model here since system message is not supported by o1 - let payload = - create_openai_request_payload(&self.config.model, system, messages, tools, false)?; + let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?; // Make request let response = self.post(payload).await?; diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs index 0be8d8917448..b822cd2edbb3 100644 --- a/crates/goose/src/providers/openai_utils.rs +++ b/crates/goose/src/providers/openai_utils.rs @@ -256,6 +256,21 @@ pub fn create_openai_request_payload( system: &str, messages: &[Message], tools: &[Tool], +) -> anyhow::Result { + create_openai_request_payload_handling_concat_response_content( + model_config, + system, + messages, + tools, + false, + ) +} + +fn create_openai_request_payload_handling_concat_response_content( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], concat_tool_response_contents: bool, ) -> anyhow::Result { let system_message = json!({ @@ -299,6 +314,21 @@ pub fn create_openai_request_payload( Ok(payload) } +pub fn create_openai_request_payload_with_concat_response_content( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], +) -> anyhow::Result { + create_openai_request_payload_handling_concat_response_content( + model_config, + system, + messages, + tools, + true, + ) +} + pub fn check_openai_context_length_error(error: &Value) -> Option { let code = error.get("code")?.as_str()?; if code == "context_length_exceeded" || code == "string_above_max_length" { From 985e3ccd0866722107215449a72dcc480e81e1f3 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 20 Dec 2024 13:49:02 +1100 Subject: [PATCH 2/6] include context_limit and estimate_factor in groq and google provider config --- crates/goose-server/src/configuration.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index c6435db591f2..bbece127666e 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -88,6 +88,10 @@ pub enum ProviderSettings { temperature: Option, #[serde(default)] max_tokens: Option, + #[serde(default)] + context_limit: Option, + #[serde(default)] + estimate_factor: Option, }, Groq { #[serde(default = "default_groq_host")] @@ -99,6 +103,10 @@ pub enum ProviderSettings { temperature: Option, #[serde(default)] max_tokens: Option, + #[serde(default)] + context_limit: Option, + #[serde(default)] + estimate_factor: Option, }, } @@ -174,12 +182,16 @@ impl ProviderSettings { model, temperature, max_tokens, + context_limit, + estimate_factor, } => ProviderConfig::Google(GoogleProviderConfig { host, api_key, model: ModelConfig::new(model) .with_temperature(temperature) - .with_max_tokens(max_tokens), + .with_max_tokens(max_tokens) + .with_context_limit(context_limit) + .with_estimate_factor(estimate_factor), }), ProviderSettings::Groq { host, @@ -187,12 +199,16 @@ impl ProviderSettings { model, temperature, max_tokens, + context_limit, + estimate_factor, } => ProviderConfig::Groq(GroqProviderConfig { host, api_key, model: ModelConfig::new(model) .with_temperature(temperature) - .with_max_tokens(max_tokens), + .with_max_tokens(max_tokens) + .with_context_limit(context_limit) + .with_estimate_factor(estimate_factor), }), } } From 6d1a9cfc552750ee94969cdcbd97e1ca0f3c7e21 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 20 Dec 2024 14:14:39 +1100 Subject: [PATCH 3/6] moved get_usage in the traits --- crates/goose/src/providers/anthropic.rs | 52 ++++++++++++------------ crates/goose/src/providers/base.rs | 3 ++ crates/goose/src/providers/databricks.rs | 14 +++---- crates/goose/src/providers/google.rs | 42 +++++++++---------- crates/goose/src/providers/groq.rs | 10 ++--- crates/goose/src/providers/mock.rs | 5 +++ crates/goose/src/providers/ollama.rs | 10 ++--- crates/goose/src/providers/openai.rs | 10 ++--- 8 files changed, 77 insertions(+), 69 deletions(-) diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index c23d77df7e7f..61f1094b23f0 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -33,29 +33,6 @@ impl AnthropicProvider { Ok(Self { client, config }) } - fn get_usage(data: &Value) -> Result { - // Extract usage data if available - if let Some(usage) = data.get("usage") { - let input_tokens = usage - .get("input_tokens") - .and_then(|v| v.as_u64()) - .map(|v| v as i32); - let output_tokens = usage - .get("output_tokens") - .and_then(|v| v.as_u64()) - .map(|v| v as i32); - let total_tokens = match (input_tokens, output_tokens) { - (Some(i), Some(o)) => Some(i + o), - _ => None, - }; - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) - } else { - // If no usage data, return None for all values - Ok(Usage::new(None, None, None)) - } - } - fn tools_to_anthropic_spec(tools: &[Tool]) -> Vec { let mut unique_tools = HashSet::new(); let mut tool_specs = Vec::new(); @@ -212,6 +189,10 @@ impl AnthropicProvider { #[async_trait] impl Provider for AnthropicProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, @@ -261,15 +242,34 @@ impl Provider for AnthropicProvider { // Parse response let message = Self::parse_anthropic_response(response.clone())?; - let usage = Self::get_usage(&response)?; + let usage = self.get_usage(&response)?; let model = get_model(&response); let cost = cost(&usage, &model_pricing_for(&model)); Ok((message, ProviderUsage::new(model, usage, cost))) } - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() + fn get_usage(&self, data: &Value) -> Result { + // Extract usage data if available + if let Some(usage) = data.get("usage") { + let input_tokens = usage + .get("input_tokens") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let output_tokens = usage + .get("output_tokens") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let total_tokens = match (input_tokens, output_tokens) { + (Some(i), Some(o)) => Some(i + o), + _ => None, + }; + + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + } else { + // If no usage data, return None for all values + Ok(Usage::new(None, None, None)) + } } } diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index a6ee06eabbc8..fa52442c458c 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -48,6 +48,7 @@ impl Usage { } use async_trait::async_trait; +use serde_json::Value; /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] @@ -70,6 +71,8 @@ pub trait Provider: Send + Sync { messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage)>; + + fn get_usage(&self, data: &Value) -> Result; } #[cfg(test)] diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 98950dbee963..ab7434bd8794 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -48,10 +48,6 @@ impl DatabricksProvider { } } - fn get_usage(data: &Value) -> Result { - get_openai_usage(data) - } - async fn post(&self, payload: Value) -> Result { let url = format!( "{}/serving-endpoints/{}/invocations", @@ -74,6 +70,10 @@ impl DatabricksProvider { #[async_trait] impl Provider for DatabricksProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, @@ -136,15 +136,15 @@ impl Provider for DatabricksProvider { // Parse response let message = openai_response_to_message(response.clone())?; - let usage = Self::get_usage(&response)?; + let usage = self.get_usage(&response)?; let model = get_model(&response); let cost = cost(&usage, &model_pricing_for(&model)); Ok((message, ProviderUsage::new(model, usage, cost))) } - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() + fn get_usage(&self, data: &Value) -> Result { + get_openai_usage(data) } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index d96b681add87..97d9a92c35f2 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -28,27 +28,6 @@ impl GoogleProvider { Ok(Self { client, config }) } - fn get_usage(&self, data: &Value) -> anyhow::Result { - if let Some(usage_meta_data) = data.get("usageMetadata") { - let input_tokens = usage_meta_data - .get("promptTokenCount") - .and_then(|v| v.as_u64()) - .map(|v| v as i32); - let output_tokens = usage_meta_data - .get("candidatesTokenCount") - .and_then(|v| v.as_u64()) - .map(|v| v as i32); - let total_tokens = usage_meta_data - .get("totalTokenCount") - .and_then(|v| v.as_u64()) - .map(|v| v as i32); - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) - } else { - // If no usage data, return None for all values - Ok(Usage::new(None, None, None)) - } - } - async fn post(&self, payload: Value) -> anyhow::Result { let url = format!( "{}/v1beta/models/{}:generateContent?key={}", @@ -343,6 +322,27 @@ impl Provider for GoogleProvider { let provider_usage = ProviderUsage::new(model, usage, None); Ok((message, provider_usage)) } + + fn get_usage(&self, data: &Value) -> anyhow::Result { + if let Some(usage_meta_data) = data.get("usageMetadata") { + let input_tokens = usage_meta_data + .get("promptTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let output_tokens = usage_meta_data + .get("candidatesTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let total_tokens = usage_meta_data + .get("totalTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + } else { + // If no usage data, return None for all values + Ok(Usage::new(None, None, None)) + } + } } #[cfg(test)] // Only compiles this module when running tests diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 88c803938d59..52605836f116 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -29,10 +29,6 @@ impl GroqProvider { Ok(Self { client, config }) } - fn get_usage(data: &Value) -> anyhow::Result { - get_openai_usage(data) - } - async fn post(&self, payload: Value) -> anyhow::Result { let url = format!( "{}/openai/v1/chat/completions", @@ -72,11 +68,15 @@ impl Provider for GroqProvider { let response = self.post(payload).await?; let message = openai_response_to_message(response.clone())?; - let usage = Self::get_usage(&response)?; + let usage = self.get_usage(&response)?; let model = get_model(&response); Ok((message, ProviderUsage::new(model, usage, None))) } + + fn get_usage(&self, data: &Value) -> anyhow::Result { + get_openai_usage(data) + } } #[cfg(test)] diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index 270870b59dcd..fa84a63aff57 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -6,6 +6,7 @@ use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; use rust_decimal_macros::dec; +use serde_json::Value; use std::sync::Arc; use std::sync::Mutex; @@ -60,4 +61,8 @@ impl Provider for MockProvider { )) } } + + fn get_usage(&self, data: &Value) -> Result { + Ok(Usage::new(None, None, None)) + } } diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 84cc0b4eb2b1..8b07333ab98c 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -29,10 +29,6 @@ impl OllamaProvider { Ok(Self { client, config }) } - fn get_usage(data: &Value) -> Result { - get_openai_usage(data) - } - async fn post(&self, payload: Value) -> Result { let url = format!( "{}/v1/chat/completions", @@ -63,12 +59,16 @@ impl Provider for OllamaProvider { // Parse response let message = openai_response_to_message(response.clone())?; - let usage = Self::get_usage(&response)?; + let usage = self.get_usage(&response)?; let model = get_model(&response); let cost = None; Ok((message, ProviderUsage::new(model, usage, cost))) } + + fn get_usage(&self, data: &Value) -> Result { + get_openai_usage(data) + } } #[cfg(test)] diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 54884b040dec..6f3109585b1a 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -34,10 +34,6 @@ impl OpenAiProvider { Ok(Self { client, config }) } - fn get_usage(data: &Value) -> Result { - get_openai_usage(data) - } - async fn post(&self, payload: Value) -> Result { let url = format!( "{}/v1/chat/completions", @@ -84,12 +80,16 @@ impl Provider for OpenAiProvider { // Parse response let message = openai_response_to_message(response.clone())?; - let usage = Self::get_usage(&response)?; + let usage = self.get_usage(&response)?; let model = get_model(&response); let cost = cost(&usage, &model_pricing_for(&model)); Ok((message, ProviderUsage::new(model, usage, cost))) } + + fn get_usage(&self, data: &Value) -> Result { + get_openai_usage(data) + } } #[cfg(test)] From 1436d75a24291ad6486ef2d00277e38fb5022459 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 20 Dec 2024 15:03:26 +1100 Subject: [PATCH 4/6] make small change to trigger build --- crates/goose/src/providers.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 6f2fb5b9152f..6d564eeb8e8f 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -12,6 +12,7 @@ pub mod utils; pub mod google; pub mod groq; + #[cfg(test)] pub mod mock; #[cfg(test)] From 53074199113008880da57d7105e256f7a6d7a336 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 20 Dec 2024 15:13:07 +1100 Subject: [PATCH 5/6] fixed compilation error --- crates/goose-server/src/routes/reply.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index ae4f89de418b..c6bd4896441b 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -421,6 +421,10 @@ mod tests { fn get_model_config(&self) -> &ModelConfig { &self.model_config } + + fn get_usage(&self, data: &Value) -> anyhow::Result { + Ok(Usage::new(None, None, None)) + } } #[test] From c06ac3fc4126edff94e01717ade95fecf965e0a9 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 20 Dec 2024 16:42:16 -0800 Subject: [PATCH 6/6] Fix checking if tools are empty --- crates/goose/src/providers/openai_utils.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs index b822cd2edbb3..b30b53189d0c 100644 --- a/crates/goose/src/providers/openai_utils.rs +++ b/crates/goose/src/providers/openai_utils.rs @@ -283,7 +283,11 @@ fn create_openai_request_payload_handling_concat_response_content( &ImageFormat::OpenAi, concat_tool_response_contents, ); - let tools_spec = tools_to_openai_spec(tools)?; + let tools_spec = if !tools.is_empty() { + tools_to_openai_spec(tools)? + } else { + vec![] + }; let mut messages_array = vec![system_message]; messages_array.extend(messages_spec);