From 3b8ce1d3871a1423071b98f82aacb266692e1e17 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 14:26:15 +1100 Subject: [PATCH 01/22] added groq provider --- crates/goose-cli/src/commands/configure.rs | 3 + crates/goose-cli/src/profile.rs | 14 +++- crates/goose-server/src/configuration.rs | 37 +++++++++- crates/goose-server/src/state.rs | 8 +++ crates/goose/src/providers.rs | 1 + crates/goose/src/providers/configs.rs | 14 ++++ crates/goose/src/providers/factory.rs | 3 + crates/goose/src/providers/groq.rs | 80 ++++++++++++++++++++++ crates/goose/src/providers/openai.rs | 2 +- crates/goose/src/token_counter.rs | 1 + 10 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 crates/goose/src/providers/groq.rs diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 1308513ccf40..da079d043656 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -48,6 +48,7 @@ pub async fn handle_configure( ("ollama", "Ollama", "Local open source models"), ("anthropic", "Anthropic", "Claude models"), ("google", "Google Gemini", "Gemini models"), + ("groq", "Groq", "AI models"), ]) .interact()? .to_string() @@ -159,6 +160,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str { "ollama" => OLLAMA_MODEL, "anthropic" => "claude-3-5-sonnet-2", "google" => "gemini-1.5-flash", + "groq" => "llama3-70b-8192", _ => panic!("Invalid provider name"), } } @@ -170,6 +172,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> { "ollama" => vec!["OLLAMA_HOST"], "anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint "google" => vec!["GOOGLE_API_KEY"], + "groq" => vec!["GROQ_API_KEY"], _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index 6e03f6b387cc..c959ecb7945f 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -2,7 +2,7 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; use goose::providers::configs::{ AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, - ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, + ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, GroqProviderConfig, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -130,7 +130,17 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon .expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); ProviderConfig::Google(GoogleProviderConfig { - host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint + host: "https://generativelanguage.googleapis.com".to_string(), + api_key, + model: model_config, + }) + } + "groq" => { + let api_key = get_keyring_secret("GROQ_API_KEY", KeyRetrievalStrategy::Both) + .expect("GROQ_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); + + ProviderConfig::Google(GoogleProviderConfig { + host: "https://api.groq.com".to_string(), api_key, model: model_config, }) diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index de47633013a1..684474072421 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -1,13 +1,13 @@ use crate::error::{to_env_var, 
ConfigError}; use config::{Config, Environment}; -use goose::providers::configs::GoogleProviderConfig; +use goose::providers::configs::{GoogleProviderConfig, GroqProviderConfig}; use goose::providers::{ configs::{ DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, }, factory::ProviderType, - google, ollama, + google, ollama, groq, utils::ImageFormat, }; use serde::Deserialize; @@ -88,6 +88,17 @@ pub enum ProviderSettings { #[serde(default)] max_tokens: Option<i32>, }, + Groq { + #[serde(default = "default_groq_host")] + host: String, + api_key: String, + #[serde(default = "default_groq_model")] + model: String, + #[serde(default)] + temperature: Option<f32>, + #[serde(default)] + max_tokens: Option<i32>, + }, } impl ProviderSettings { @@ -99,6 +110,7 @@ impl ProviderSettings { ProviderSettings::Databricks { .. } => ProviderType::Databricks, ProviderSettings::Ollama { .. } => ProviderType::Ollama, ProviderSettings::Google { .. } => ProviderType::Google, + ProviderSettings::Groq { .. } => ProviderType::Groq, } } @@ -168,6 +180,19 @@ impl ProviderSettings { .with_temperature(temperature) .with_max_tokens(max_tokens), }), + ProviderSettings::Groq { + host, + api_key, + model, + temperature, + max_tokens, + } => ProviderConfig::Groq(GroqProviderConfig { + host, + api_key, + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens), + }), } } } @@ -267,6 +292,14 @@ fn default_google_model() -> String { google::GOOGLE_DEFAULT_MODEL.to_string() } +fn default_groq_host() -> String { + groq::GROQ_API_HOST.to_string() +} + +fn default_groq_model() -> String { + groq::GROQ_DEFAULT_MODEL.to_string() +} + fn default_image_format() -> ImageFormat { ImageFormat::Anthropic } diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 446c538dcee4..8973822b7078 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -7,6 +7,7 @@ use goose::{ }; use std::sync::Arc; use tokio::sync::Mutex; +use goose::providers::configs::GroqProviderConfig; /// Shared application state pub struct AppState { @@ -71,6 +72,13 @@ impl Clone for AppState { model: config.model.clone(), }) } + ProviderConfig::Groq(config) => { + ProviderConfig::Groq(GroqProviderConfig { + host: config.host.clone(), + api_key: config.api_key.clone(), + model: config.model.clone(), + }) + } }, agent: self.agent.clone(), secret_key: self.secret_key.clone(), diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index f2d7758aec67..f0e504793ce9 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -12,3 +12,4 @@ pub mod utils; pub mod google; #[cfg(test)] pub mod mock; +pub mod groq; diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 67c49282dc5f..94f6d585d3eb 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -14,6 +14,7 @@ pub enum ProviderConfig { Ollama(OllamaProviderConfig), Anthropic(AnthropicProviderConfig), Google(GoogleProviderConfig), + Groq(GroqProviderConfig), } /// Configuration for model-specific settings and limits @@ -222,6 +223,19 @@ impl ProviderModelConfig for GoogleProviderConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GroqProviderConfig { + pub host: String, + pub api_key: String, + pub model: ModelConfig, +} + +impl ProviderModelConfig for GroqProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } +}
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaProviderConfig { pub host: String, diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index f5a9c0931dfe..68780905d143 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -5,6 +5,7 @@ use super::{ }; use anyhow::Result; use strum_macros::EnumIter; +use crate::providers::groq::GroqProvider; #[derive(EnumIter, Debug)] pub enum ProviderType { @@ -13,6 +14,7 @@ pub enum ProviderType { Ollama, Anthropic, Google, + Groq } pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send>> { @@ -26,5 +28,6 @@ pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send>> { ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)), + ProviderConfig::Groq(groq_config) => Ok(Box::new(GroqProvider::new(groq_config)?)), } } diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs new file mode 100644 index 000000000000..72b25d5cd422 --- /dev/null +++ b/crates/goose/src/providers/groq.rs @@ -0,0 +1,80 @@ +use std::time::Duration; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::{json, Map, Value}; +use mcp_core::Tool; +use crate::message::Message; +use crate::providers::base::{Provider, ProviderUsage}; +use crate::providers::configs::{GroqProviderConfig, ModelConfig}; +use crate::providers::google::GoogleProvider; +use crate::providers::utils::unescape_json_values; + +pub const GROQ_API_HOST: &str = "https://api.groq.com"; +pub const GROQ_DEFAULT_MODEL: &str = "llama3-70b-8192"; + +pub struct GroqProvider { + client: Client, + config: GroqProviderConfig, +} + +impl GroqProvider { + pub fn new(config: GroqProviderConfig) -> anyhow::Result<Self> { + let client = Client::builder() + .timeout(Duration::from_secs(600)) // 10 minutes timeout + .build()?; + + Ok(Self { client, config }) + } +} + +#[async_trait] +impl Provider for GroqProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> anyhow::Result<(Message, ProviderUsage)> { + let mut payload = Map::new(); + payload.insert( + "system_instruction".to_string(), + json!({"parts": [{"text": system}]}), + ); + payload.insert( + "contents".to_string(), + json!(self.messages_to_google_spec(&messages)), + ); + if !tools.is_empty() { + payload.insert( + "tools".to_string(), + json!({"functionDeclarations": self.tools_to_google_spec(&tools)}), + ); + } + let mut generation_config = Map::new(); + if let Some(temp) = self.config.model.temperature { + generation_config.insert("temperature".to_string(), json!(temp)); + } + if let Some(tokens) = self.config.model.max_tokens { + generation_config.insert("maxOutputTokens".to_string(), json!(tokens)); + } + if !generation_config.is_empty() { + payload.insert("generationConfig".to_string(), json!(generation_config)); + } + + // Make request + let response = self.post(Value::Object(payload)).await?; + // Parse response + let message = self.google_response_to_message(unescape_json_values(&response))?; + let usage = self.get_usage(&response)?; + let model = match response.get("modelVersion") { + Some(model_version) => model_version.as_str().unwrap_or_default().to_string(), + None => self.config.model.model_name.clone(), + }; + let provider_usage = ProviderUsage::new(model, usage, None); + Ok((message, provider_usage)) + } +} \ No newline at end of file diff --git a/crates/goose/src/providers/openai.rs
b/crates/goose/src/providers/openai.rs index 8b6a2748c9b7..d00530f6e03f 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -69,7 +69,7 @@ impl OpenAiProvider { let response = self .client .post(&url) - .header("Authorization", format!("Bearer {}", self.config.api_key)) + // .header("Authorization", format!("Bearer {}", self.config.api_key)) .json(&payload) .send() .await?; diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 0a7a5d1127cc..0a6e2dac779e 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -65,6 +65,7 @@ impl TokenCounter { fn model_to_tokenizer_key(model_name: Option<&str>) -> &str { let model_name = model_name.unwrap_or("gpt-4o").to_lowercase(); + // Lifei: TODO: add llamas to the list if model_name.contains("claude") { CLAUDE_TOKENIZER_KEY } else if model_name.contains("qwen") { From eef6f3c80194c7b3be57ac061933298fce367840 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 14:38:36 +1100 Subject: [PATCH 02/22] extract handle_response --- crates/goose/src/providers/google.rs | 17 ++-------------- crates/goose/src/providers/ollama.rs | 25 ++++++----------------- crates/goose/src/providers/openai.rs | 30 +++++++++------------------- crates/goose/src/providers/utils.rs | 18 ++++++++++++++++- 4 files changed, 34 insertions(+), 56 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index e5ab052de328..a1df3b71bb41 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -2,9 +2,7 @@ use crate::errors::AgentError; use crate::message::{Message, MessageContent}; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; -use crate::providers::utils::{ - is_valid_function_name, sanitize_function_name, unescape_json_values, -}; +use crate::providers::utils::{handle_response, is_valid_function_name, sanitize_function_name, unescape_json_values}; use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; @@ -66,18 +64,7 @@ impl GoogleProvider { .send() .await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { - // Implement retry logic here if needed - Err(anyhow!("Server error: {}", status)) - } - _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), - } + handle_response(payload, response).await? 
} fn messages_to_google_spec(&self, messages: &[Message]) -> Vec<Value> { diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index feee301cf16f..180789ee2664 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,9 +1,6 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; -use super::utils::{ - get_model, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, - ImageFormat, -}; +use super::utils::{get_model, handle_response, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat}; use crate::message::Message; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -65,22 +62,16 @@ impl OllamaProvider { let response = self.client.post(&url).json(&payload).send().await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { - Err(anyhow!("Server error: {}", status)) - } - _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), - } + handle_response(payload, response).await? } } #[async_trait] impl Provider for OllamaProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, @@ -132,10 +123,6 @@ impl Provider for OllamaProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } - - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() - } } #[cfg(test)] diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index d00530f6e03f..a42982e165fe 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,7 +1,6 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; -use reqwest::Client; -use reqwest::StatusCode; +use reqwest::{Client}; use serde_json::{json, Value}; use std::time::Duration; @@ -11,7 +10,7 @@ use super::configs::OpenAiProviderConfig; use super::configs::{ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; -use super::utils::get_model; +use super::utils::{get_model, handle_response}; use super::utils::{ check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, }; @@ -69,28 +68,21 @@ impl OpenAiProvider { let response = self .client .post(&url) - // .header("Authorization", format!("Bearer {}", self.config.api_key)) + .header("Authorization", format!("Bearer {}", self.config.api_key)) .json(&payload) .send() .await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { - // Implement retry logic here if needed - Err(anyhow!("Server error: {}", status)) - } - _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), - } + handle_response(payload, response).await?
} } #[async_trait] impl Provider for OpenAiProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, @@ -160,10 +152,6 @@ impl Provider for OpenAiProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } - - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() - } } #[cfg(test)] @@ -300,7 +288,7 @@ mod tests { // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { - let tool_call = tool_request.tool_call.as_ref().unwrap(); + let tool_call = tool_request.tool_call.as_ref()?; assert_eq!(tool_call.name, "get_weather"); assert_eq!( tool_call.arguments, diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index f3bd5d8ee516..ba39ced34399 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -1,5 +1,6 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use regex::Regex; +use reqwest::{Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; @@ -234,6 +235,21 @@ pub fn openai_response_to_message(response: Value) -> Result<Message> { }) } +pub async fn handle_response(payload: Value, response: Response) -> Result<Result<Value, Error>, Error> { + Ok(match response.status() { + StatusCode::OK => Ok(response.json().await?), + status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { + // Implement retry logic here if needed + Err(anyhow!("Server error: {}", status)) + } + _ => Err(anyhow!( + "Request failed: {}\nPayload: {}", + response.status(), + payload + )), + }) +} + pub fn sanitize_function_name(name: &str) -> String { let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap(); re.replace_all(name, "_").to_string() From 4715cb9e7e7fbe273b1db063738337cfdd45f781 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 15:04:02 +1100 Subject: [PATCH 03/22] extract more functions --- crates/goose-cli/src/profile.rs | 2 +- crates/goose/src/providers.rs | 2 +- crates/goose/src/providers/factory.rs | 4 +- crates/goose/src/providers/google.rs | 4 +- crates/goose/src/providers/groq.rs | 85 +++++++++++++-------------- crates/goose/src/providers/ollama.rs | 65 ++------------------ crates/goose/src/providers/openai.rs | 72 ++--------------------- crates/goose/src/providers/utils.rs | 82 ++++++++++++++++++++++++-- 8 files changed, 135 insertions(+), 181 deletions(-) diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index c959ecb7945f..e96ad21802ed 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -2,7 +2,7 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; use goose::providers::configs::{ AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, - ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, GroqProviderConfig, + GroqProviderConfig, ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index f0e504793ce9..64eca59ed9dd 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -10,6 +10,6 @@ pub mod openai; pub mod utils; pub mod google; +pub mod groq; #[cfg(test)] pub mod mock; -pub mod groq; diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index
68780905d143..eaf23fc16b8b 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -3,9 +3,9 @@ use super::{ databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider, openai::OpenAiProvider, }; +use crate::providers::groq::GroqProvider; use anyhow::Result; use strum_macros::EnumIter; -use crate::providers::groq::GroqProvider; #[derive(EnumIter, Debug)] pub enum ProviderType { @@ -14,7 +14,7 @@ pub enum ProviderType { Ollama, Anthropic, Google, - Groq + Groq, } pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send>> { diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index a1df3b71bb41..aafe54535106 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -2,7 +2,9 @@ use crate::errors::AgentError; use crate::message::{Message, MessageContent}; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; -use crate::providers::utils::{handle_response, is_valid_function_name, sanitize_function_name, unescape_json_values}; +use crate::providers::utils::{ + handle_response, is_valid_function_name, sanitize_function_name, unescape_json_values, +}; use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 72b25d5cd422..efc4153c40db 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,13 +1,17 @@ -use std::time::Duration; +use crate::message::Message; +use crate::providers::base::{Provider, ProviderUsage, Usage}; +use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; +use crate::providers::utils::{ + create_openai_request_payload, get_model, get_openai_usage, handle_response, + openai_response_to_message, unescape_json_values, +}; +use anyhow::anyhow; use async_trait::async_trait; +use mcp_core::Tool; use reqwest::Client; use serde_json::{json, Map, Value}; -use mcp_core::Tool; -use crate::message::Message; -use crate::providers::base::{Provider, ProviderUsage}; -use crate::providers::configs::{GroqProviderConfig, ModelConfig}; -use crate::providers::google::GoogleProvider; -use crate::providers::utils::unescape_json_values; +use std::time::Duration; pub const GROQ_API_HOST: &str = "https://api.groq.com"; pub const GROQ_DEFAULT_MODEL: &str = "llama3-70b-8192"; @@ -25,6 +29,27 @@ impl GroqProvider { Ok(Self { client, config }) } + + fn get_usage(data: &Value) -> anyhow::Result<Usage> { + get_openai_usage(data) + } + + async fn post(&self, payload: Value) -> anyhow::Result<Value> { + let url = format!( + "{}/openai/v1/chat/completions", + self.config.host.trim_end_matches('/') + ); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.config.api_key)) + .json(&payload) + .send() + .await?; + + handle_response(payload, response).await?
+ } } #[async_trait] impl Provider for GroqProvider { @@ -39,42 +64,14 @@ impl Provider for GroqProvider { messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage)> { - let mut payload = Map::new(); - payload.insert( - "system_instruction".to_string(), - json!({"parts": [{"text": system}]}), - ); - payload.insert( - "contents".to_string(), - json!(self.messages_to_google_spec(&messages)), - ); - if !tools.is_empty() { - payload.insert( - "tools".to_string(), - json!({"functionDeclarations": self.tools_to_google_spec(&tools)}), - ); - } - let mut generation_config = Map::new(); - if let Some(temp) = self.config.model.temperature { - generation_config.insert("temperature".to_string(), json!(temp)); - } - if let Some(tokens) = self.config.model.max_tokens { - generation_config.insert("maxOutputTokens".to_string(), json!(tokens)); - } - if !generation_config.is_empty() { - payload.insert("generationConfig".to_string(), json!(generation_config)); - } + let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?; + + let response = self.post(payload).await?; - // Make request - let response = self.post(Value::Object(payload)).await?; - // Parse response - let message = self.google_response_to_message(unescape_json_values(&response))?; - let usage = self.get_usage(&response)?; - let model = match response.get("modelVersion") { - Some(model_version) => model_version.as_str().unwrap_or_default().to_string(), - None => self.config.model.model_name.clone(), - }; - let provider_usage = ProviderUsage::new(model, usage, None); - Ok((message, provider_usage)) + let message = openai_response_to_message(response.clone())?; + let usage = Self::get_usage(&response)?; + let model = get_model(&response); + + Ok((message, ProviderUsage::new(model, usage, None))) } -} \ No newline at end of file +} diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 180789ee2664..9fed775a3b7f 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,6 +1,9 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; -use super::utils::{get_model, handle_response, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat}; +use super::utils::{ + create_openai_request_payload, get_model, get_openai_usage, handle_response, + messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, +}; use crate::message::Message; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -28,30 +31,7 @@ impl OllamaProvider { } fn get_usage(data: &Value) -> Result<Usage> { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + get_openai_usage(data) } async fn post(&self, payload: Value) -> Result<Value> { @@ -78,40 +58,7 @@ impl Provider for OllamaProvider { messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { - let system_message = json!({ - "role": "system",
"content": system - }); - - let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); - let tools_spec = tools_to_openai_spec(tools)?; - - let mut messages_array = vec![system_message]; - messages_array.extend(messages_spec); - - let mut payload = json!({ - "model": self.config.model.model_name, - "messages": messages_array - }); - - if !tools_spec.is_empty() { - payload - .as_object_mut() - .unwrap() - .insert("tools".to_string(), json!(tools_spec)); - } - if let Some(temp) = self.config.model.temperature { - payload - .as_object_mut() - .unwrap() - .insert("temperature".to_string(), json!(temp)); - } - if let Some(tokens) = self.config.model.max_tokens { - payload - .as_object_mut() - .unwrap() - .insert("max_tokens".to_string(), json!(tokens)); - } + let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?; let response = self.post(payload).await?; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index a42982e165fe..a8500350260a 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; -use reqwest::{Client}; +use reqwest::Client; use serde_json::{json, Value}; use std::time::Duration; @@ -10,11 +10,11 @@ use super::configs::OpenAiProviderConfig; use super::configs::{ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; -use super::utils::{get_model, handle_response}; use super::utils::{ check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, }; +use super::utils::{create_openai_request_payload, get_model, get_openai_usage, handle_response}; use crate::message::Message; use mcp_core::tool::Tool; @@ -33,30 +33,7 @@ impl OpenAiProvider { } fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + get_openai_usage(data) } async fn post(&self, payload: Value) -> Result { @@ -90,48 +67,7 @@ impl Provider for OpenAiProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage)> { // Not checking for o1 model here since system message is not supported by o1 - let system_message = json!({ - "role": "system", - "content": system - }); - - // Convert messages and tools to OpenAI format - let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); - let tools_spec = if !tools.is_empty() { - tools_to_openai_spec(tools)? 
- } else { - vec![] - }; - - // Build payload - // create messages array with system message first - let mut messages_array = vec![system_message]; - messages_array.extend(messages_spec); - - let mut payload = json!({ - "model": self.config.model.model_name, - "messages": messages_array - }); - - // Add optional parameters - if !tools_spec.is_empty() { - payload - .as_object_mut() - .unwrap() - .insert("tools".to_string(), json!(tools_spec)); - } - if let Some(temp) = self.config.model.temperature { - payload - .as_object_mut() - .unwrap() - .insert("temperature".to_string(), json!(temp)); - } - if let Some(tokens) = self.config.model.max_tokens { - payload - .as_object_mut() - .unwrap() - .insert("max_tokens".to_string(), json!(tokens)); - } + let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?; // Make request let response = self.post(payload).await?; diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index ba39ced34399..423d429ef0fc 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -6,6 +6,8 @@ use serde_json::{json, Map, Value}; use crate::errors::AgentError; use crate::message::{Message, MessageContent}; +use crate::providers::base::Usage; +use crate::providers::configs::ModelConfig; use mcp_core::content::{Content, ImageContent}; use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolCall}; @@ -243,13 +245,83 @@ pub async fn handle_response(payload: Value, response: Response) -> Result<Result<Value, Error>, Error> { _ => Err(anyhow!( - "Request failed: {}\nPayload: {}", - response.status(), - payload - )), + "Request failed: {}\nPayload: {}", + response.status(), + payload + )), }) } +pub fn get_openai_usage(data: &Value) -> Result<Usage> { + let usage = data + .get("usage") + .ok_or_else(|| anyhow!("No usage data in response"))?; + + let input_tokens = usage + .get("prompt_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let output_tokens = usage + .get("completion_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let total_tokens = usage + .get("total_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32) + .or_else(|| match (input_tokens, output_tokens) { + (Some(input), Some(output)) => Some(input + output), + _ => None, + }); + + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) +} + +pub fn create_openai_request_payload( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], +) -> Result<Value> { + let system_message = json!({ + "role": "system", + "content": system + }); + + let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); + let tools_spec = tools_to_openai_spec(tools)?; + + let mut messages_array = vec![system_message]; + messages_array.extend(messages_spec); + + let mut payload = json!({ + "model": model_config.model_name, + "messages": messages_array + }); + + if !tools_spec.is_empty() { + payload + .as_object_mut() + .unwrap() + .insert("tools".to_string(), json!(tools_spec)); + } + if let Some(temp) = model_config.temperature { + payload + .as_object_mut() + .unwrap() + .insert("temperature".to_string(), json!(temp)); + } + if let Some(tokens) = model_config.max_tokens { + payload + .as_object_mut() + .unwrap() + .insert("max_tokens".to_string(), json!(tokens)); + } + Ok(payload) +} + pub fn sanitize_function_name(name: &str) -> String { let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap(); re.replace_all(name, "_").to_string() @@ -533,7 +605,7 @@ mod tests { assert_eq!(message.content.len(), 1); if let
MessageContent::ToolRequest(request) = &message.content[0] { - let tool_call = request.tool_call.as_ref().unwrap(); + let tool_call = request.tool_call.as_ref()?; assert_eq!(tool_call.name, "example_fn"); assert_eq!(tool_call.arguments, json!({"param": "value"})); } else { From 07147ebca525463a41c3fcaf64f0f0de687f96fb Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 15:04:41 +1100 Subject: [PATCH 04/22] fixed more format --- crates/goose/src/providers/google.rs | 3 +-- crates/goose/src/providers/groq.rs | 6 ++---- crates/goose/src/providers/ollama.rs | 8 +++----- crates/goose/src/providers/openai.rs | 5 ++--- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index aafe54535106..3ed83b2722f9 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -5,10 +5,9 @@ use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModel use crate::providers::utils::{ handle_response, is_valid_function_name, sanitize_function_name, unescape_json_values, }; -use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; -use reqwest::{Client, StatusCode}; +use reqwest::Client; use serde_json::{json, Map, Value}; use std::time::Duration; diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index efc4153c40db..39fec5d327c8 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,16 +1,14 @@ use crate::message::Message; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; -use crate::providers::google::GoogleProvider; use crate::providers::utils::{ create_openai_request_payload, get_model, get_openai_usage, handle_response, - openai_response_to_message, unescape_json_values, + openai_response_to_message, }; -use anyhow::anyhow; use async_trait::async_trait; use mcp_core::Tool; use reqwest::Client; -use serde_json::{json, Map, Value}; +use serde_json::Value; use std::time::Duration; pub const GROQ_API_HOST: &str = "https://api.groq.com"; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 9fed775a3b7f..6fde4de8c6e9 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,16 +1,14 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; use super::utils::{ - create_openai_request_payload, get_model, get_openai_usage, handle_response, - messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, + create_openai_request_payload, get_model, get_openai_usage, handle_response, openai_response_to_message, }; use crate::message::Message; -use anyhow::{anyhow, Result}; +use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; use reqwest::Client; -use reqwest::StatusCode; -use serde_json::{json, Value}; +use serde_json::Value; use std::time::Duration; pub const OLLAMA_HOST: &str = "http://localhost:11434"; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index a8500350260a..f31998af3df3 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use reqwest::Client; -use serde_json::{json, Value}; +use 
serde_json::Value; use std::time::Duration; use super::base::ProviderUsage; @@ -11,8 +11,7 @@ use super::configs::{ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; use super::utils::{ - check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, - tools_to_openai_spec, ImageFormat, + check_openai_context_length_error, openai_response_to_message, }; use super::utils::{create_openai_request_payload, get_model, get_openai_usage, handle_response}; use crate::message::Message; From 0a9170666f9a5399a91c266495dae136020154ac Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 16:09:42 +1100 Subject: [PATCH 05/22] convert content to string to match api schema --- crates/goose-cli/src/profile.rs | 2 +- crates/goose/src/providers/groq.rs | 1 - crates/goose/src/providers/utils.rs | 13 +++++++++---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index e96ad21802ed..d62dfc5b868e 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -139,7 +139,7 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon let api_key = get_keyring_secret("GROQ_API_KEY", KeyRetrievalStrategy::Both) .expect("GROQ_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); - ProviderConfig::Google(GoogleProviderConfig { + ProviderConfig::Groq(GroqProviderConfig { host: "https://api.groq.com".to_string(), api_key, model: model_config, diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 39fec5d327c8..2cfdeeb67057 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -45,7 +45,6 @@ impl GroqProvider { .json(&payload) .send() .await?; - handle_response(payload, response).await? 
} } diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 423d429ef0fc..6bb46fd12580 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -23,7 +23,6 @@ pub enum ImageFormat { /// even though the message structure is otherwise following openai, the enum switches this pub fn messages_to_openai_spec(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> { let mut messages_spec = Vec::new(); - for message in messages { let mut converted = json!({ "role": message.role @@ -99,14 +98,20 @@ pub fn messages_to_openai_spec(messages: &[Message], image_format: &ImageFormat) } } } - + let concatenated_content = tool_content + .iter() + .map(|content| match content { + Content::Text(text) => text.text.clone(), + _ => String::new(), + }) + .collect::<Vec<String>>() + .join(" "); // First add the tool response with all content output.push(json!({ "role": "tool", - "content": tool_content, + "content": concatenated_content, "tool_call_id": response.id })); - // Then add any image messages that need to follow output.extend(image_messages); } From ab705fd4bb4aa77b5298b80f02e0d969e04ffe31 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 17:31:06 +1100 Subject: [PATCH 06/22] moved open ai specific utils to a separate file --- crates/goose/src/providers.rs | 3 +- crates/goose/src/providers/databricks.rs | 4 +- crates/goose/src/providers/groq.rs | 6 +- crates/goose/src/providers/ollama.rs | 6 +- crates/goose/src/providers/openai.rs | 6 +- crates/goose/src/providers/openai_utils.rs | 560 +++++++++++++++++++++ crates/goose/src/providers/utils.rs | 556 +------------------- 7 files changed, 572 insertions(+), 569 deletions(-) create mode 100644 crates/goose/src/providers/openai_utils.rs diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 64eca59ed9dd..9dc54e46d6fb 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -7,9 +7,10 @@ pub mod model_pricing; pub mod oauth; pub mod ollama; pub mod openai; +pub mod openai_utils; pub mod utils; pub mod google; pub mod groq; #[cfg(test)] -pub mod mock; +pub mod mock; \ No newline at end of file diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 460341d450e5..6b0ad1e17ea0 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -9,11 +9,11 @@ use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, Prov use super::model_pricing::{cost, model_pricing_for}; use super::oauth; use super::utils::{ - check_bedrock_context_length_error, check_openai_context_length_error, get_model, - messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, + check_bedrock_context_length_error, get_model }; use crate::message::Message; use mcp_core::tool::Tool; +use crate::providers::openai_utils::{check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec}; pub struct DatabricksProvider { client: Client, diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 2cfdeeb67057..4519363593d2 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,15 +1,13 @@ use crate::message::Message; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; -use crate::providers::utils::{ - create_openai_request_payload, get_model,
get_openai_usage, handle_response, - openai_response_to_message, -}; +use crate::providers::utils::{get_model, handle_response}; use async_trait::async_trait; use mcp_core::Tool; use reqwest::Client; use serde_json::Value; use std::time::Duration; +use crate::providers::openai_utils::{create_openai_request_payload, get_openai_usage, openai_response_to_message}; pub const GROQ_API_HOST: &str = "https://api.groq.com"; pub const GROQ_DEFAULT_MODEL: &str = "llama3-70b-8192"; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 6fde4de8c6e9..d845eb7c666b 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,8 +1,6 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; -use super::utils::{ - create_openai_request_payload, get_model, get_openai_usage, handle_response, openai_response_to_message, -}; +use super::utils::{ get_model, handle_response, }; use crate::message::Message; use anyhow::Result; use async_trait::async_trait; @@ -10,6 +8,7 @@ use mcp_core::tool::Tool; use reqwest::Client; use serde_json::Value; use std::time::Duration; +use crate::providers::openai_utils::{create_openai_request_payload, get_openai_usage, openai_response_to_message}; pub const OLLAMA_HOST: &str = "http://localhost:11434"; pub const OLLAMA_MODEL: &str = "qwen2.5"; @@ -195,7 +194,6 @@ mod tests { let (message, usage) = provider .complete("You are a helpful assistant.", &messages, &[tool]) .await?; - // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { let tool_call = tool_request.tool_call.as_ref().unwrap(); diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index f31998af3df3..355a49b88390 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -10,12 +10,10 @@ use super::configs::OpenAiProviderConfig; use super::configs::{ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; -use super::utils::{ - check_openai_context_length_error, openai_response_to_message, -}; -use super::utils::{create_openai_request_payload, get_model, get_openai_usage, handle_response}; +use super::utils::{get_model, handle_response}; use crate::message::Message; use mcp_core::tool::Tool; +use crate::providers::openai_utils::{check_openai_context_length_error, create_openai_request_payload, get_openai_usage, openai_response_to_message}; pub struct OpenAiProvider { client: Client, diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs new file mode 100644 index 000000000000..18cfd0347cd7 --- /dev/null +++ b/crates/goose/src/providers/openai_utils.rs @@ -0,0 +1,560 @@ +use anyhow::{anyhow, Error}; +use serde_json::{json, Value}; +use mcp_core::{Content, Role, Tool, ToolCall}; +use crate::errors::AgentError; +use crate::message::{Message, MessageContent}; +use crate::providers::base::Usage; +use crate::providers::configs::ModelConfig; +use crate::providers::utils::{convert_image, is_valid_function_name, sanitize_function_name, ContextLengthExceededError, ImageFormat}; + +/// Convert internal Message format to OpenAI's API message specification +/// some openai compatible endpoints use the anthropic image spec at the content level +/// even though the message structure is otherwise following openai, the enum switches this +pub fn 
messages_to_openai_spec(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> { + let mut messages_spec = Vec::new(); + for message in messages { + let mut converted = json!({ + "role": message.role + }); + + let mut output = Vec::new(); + + for content in &message.content { + match content { + MessageContent::Text(text) => { + if !text.text.is_empty() { + converted["content"] = json!(text.text); + } + } + MessageContent::ToolRequest(request) => match &request.tool_call { + Ok(tool_call) => { + let sanitized_name = sanitize_function_name(&tool_call.name); + let tool_calls = converted + .as_object_mut() + .unwrap() + .entry("tool_calls") + .or_insert(json!([])); + + tool_calls.as_array_mut().unwrap().push(json!({ + "id": request.id, + "type": "function", + "function": { + "name": sanitized_name, + "arguments": tool_call.arguments.to_string(), + } + })); + } + Err(e) => { + output.push(json!({ + "role": "tool", + "content": format!("Error: {}", e), + "tool_call_id": request.id + })); + } + }, + MessageContent::ToolResponse(response) => { + match &response.tool_result { + Ok(contents) => { + // Send only contents with no audience or with Assistant in the audience + let abridged: Vec<_> = contents + .iter() + .filter(|content| { + content + .audience() + .is_none_or(|audience| audience.contains(&Role::Assistant)) + }) + .map(|content| content.unannotated()) + .collect(); + + // Process all content, replacing images with placeholder text + let mut tool_content = Vec::new(); + let mut image_messages = Vec::new(); + + for content in abridged { + match content { + Content::Image(image) => { + // Add placeholder text in the tool response + tool_content.push(Content::text("This tool result included an image that is uploaded in the next message.")); + + // Create a separate image message + image_messages.push(json!({ + "role": "user", + "content": [convert_image(&image, image_format)] + })); + } + _ => { + tool_content.push(content); + } + } + } + let concatenated_content = tool_content + .iter() + .map(|content| match content { + Content::Text(text) => text.text.clone(), + _ => String::new(), + }) + .collect::<Vec<String>>() + .join(" "); + // First add the tool response with all content + output.push(json!({ + "role": "tool", + "content": concatenated_content, + "tool_call_id": response.id + })); + // Then add any image messages that need to follow + output.extend(image_messages); + } + Err(e) => { + // A tool result error is shown as output so the model can interpret the error message + output.push(json!({ + "role": "tool", + "content": format!("The tool call returned the following error:\n{}", e), + "tool_call_id": response.id + })); + } + } + } + MessageContent::Image(image) => { + // Handle direct image content + converted["content"] = json!([convert_image(image, image_format)]); + } + } + } + + if converted.get("content").is_some() || converted.get("tool_calls").is_some() { + output.insert(0, converted); + } + messages_spec.extend(output); + } + + messages_spec +} + +/// Convert internal Tool format to OpenAI's API tool specification +pub fn tools_to_openai_spec(tools: &[Tool]) -> anyhow::Result<Vec<Value>> { + let mut tool_names = std::collections::HashSet::new(); + let mut result = Vec::new(); + + for tool in tools { + if !tool_names.insert(&tool.name) { + return Err(anyhow!("Duplicate tool name: {}", tool.name)); + } + + result.push(json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + })); + } + + Ok(result) +} + +/// 
Convert OpenAI's API response to internal Message format +pub fn openai_response_to_message(response: Value) -> anyhow::Result<Message> { + let original = response["choices"][0]["message"].clone(); + let mut content = Vec::new(); + + if let Some(text) = original.get("content") { + if let Some(text_str) = text.as_str() { + content.push(MessageContent::text(text_str)); + } + } + + if let Some(tool_calls) = original.get("tool_calls") { + if let Some(tool_calls_array) = tool_calls.as_array() { + for tool_call in tool_calls_array { + let id = tool_call["id"].as_str().unwrap_or_default().to_string(); + let function_name = tool_call["function"]["name"] + .as_str() + .unwrap_or_default() + .to_string(); + let arguments = tool_call["function"]["arguments"] + .as_str() + .unwrap_or_default() + .to_string(); + + if !is_valid_function_name(&function_name) { + let error = AgentError::ToolNotFound(format!( + "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", + function_name + )); + content.push(MessageContent::tool_request(id, Err(error))); + } else { + match serde_json::from_str::<Value>(&arguments) { + Ok(params) => { + content.push(MessageContent::tool_request( + id, + Ok(ToolCall::new(&function_name, params)), + )); + } + Err(e) => { + let error = AgentError::InvalidParameters(format!( + "Could not interpret tool use parameters for id {}: {}", + id, e + )); + content.push(MessageContent::tool_request(id, Err(error))); + } + } + } + } + } + } + + Ok(Message { + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content, + }) +} + +pub fn get_openai_usage(data: &Value) -> anyhow::Result<Usage> { + let usage = data + .get("usage") + .ok_or_else(|| anyhow!("No usage data in response"))?; + + let input_tokens = usage + .get("prompt_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let output_tokens = usage + .get("completion_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let total_tokens = usage + .get("total_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32) + .or_else(|| match (input_tokens, output_tokens) { + (Some(input), Some(output)) => Some(input + output), + _ => None, + }); + + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) +} + +pub fn create_openai_request_payload( + model_config: &ModelConfig, + system: &str, + messages: &[Message], + tools: &[Tool], +) -> anyhow::Result<Value> { + let system_message = json!({ + "role": "system", + "content": system + }); + + let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); + let tools_spec = tools_to_openai_spec(tools)?; + + let mut messages_array = vec![system_message]; + messages_array.extend(messages_spec); + + let mut payload = json!({ + "model": model_config.model_name, + "messages": messages_array + }); + + if !tools_spec.is_empty() { + payload + .as_object_mut() + .unwrap() + .insert("tools".to_string(), json!(tools_spec)); + } + if let Some(temp) = model_config.temperature { + payload + .as_object_mut() + .unwrap() + .insert("temperature".to_string(), json!(temp)); + } + if let Some(tokens) = model_config.max_tokens { + payload + .as_object_mut() + .unwrap() + .insert("max_tokens".to_string(), json!(tokens)); + } + Ok(payload) +} + +pub fn check_openai_context_length_error(error: &Value) -> Option<ContextLengthExceededError> { + let code = error.get("code")?.as_str()?; + if code == "context_length_exceeded" || code == "string_above_max_length" { + let message = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error") + .to_string();
Some(ContextLengthExceededError(message)) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mcp_core::content::Content; + use serde_json::json; + + const OPENAI_TOOL_USE_RESPONSE: &str = r#"{ + "choices": [{ + "role": "assistant", + "message": { + "tool_calls": [{ + "id": "1", + "function": { + "name": "example_fn", + "arguments": "{\"param\": \"value\"}" + } + }] + } + }], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35 + } + }"#; + + #[test] + fn test_messages_to_openai_spec() -> anyhow::Result<()> { + let message = Message::user().with_text("Hello"); + let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi); + + assert_eq!(spec.len(), 1); + assert_eq!(spec[0]["role"], "user"); + assert_eq!(spec[0]["content"], "Hello"); + Ok(()) + } + + #[test] + fn test_tools_to_openai_spec() -> anyhow::Result<()> { + let tool = Tool::new( + "test_tool", + "A test tool", + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Test parameter" + } + }, + "required": ["input"] + }), + ); + + let spec = tools_to_openai_spec(&[tool])?; + + assert_eq!(spec.len(), 1); + assert_eq!(spec[0]["type"], "function"); + assert_eq!(spec[0]["function"]["name"], "test_tool"); + Ok(()) + } + + #[test] + fn test_messages_to_openai_spec_complex() -> anyhow::Result<()> { + let mut messages = vec![ + Message::assistant().with_text("Hello!"), + Message::user().with_text("How are you?"), + Message::assistant().with_tool_request( + "tool1", + Ok(ToolCall::new("example", json!({"param1": "value1"}))), + ), + ]; + + // Get the ID from the tool request to use in the response + let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] { + request.id.clone() + } else { + panic!("should be tool request"); + }; + + messages + .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); + + let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi); + + assert_eq!(spec.len(), 4); + assert_eq!(spec[0]["role"], "assistant"); + assert_eq!(spec[0]["content"], "Hello!"); + assert_eq!(spec[1]["role"], "user"); + assert_eq!(spec[1]["content"], "How are you?"); + assert_eq!(spec[2]["role"], "assistant"); + assert!(spec[2]["tool_calls"].is_array()); + assert_eq!(spec[3]["role"], "tool"); + assert_eq!( + spec[3]["content"], + json!([{"text": "Result", "type": "text"}]) + ); + assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]); + + Ok(()) + } + + #[test] + fn test_tools_to_openai_spec_duplicate() -> anyhow::Result<()> { + let tool1 = Tool::new( + "test_tool", + "Test tool", + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Test parameter" + } + }, + "required": ["input"] + }), + ); + + let tool2 = Tool::new( + "test_tool", + "Test tool", + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "Test parameter" + } + }, + "required": ["input"] + }), + ); + + let result = tools_to_openai_spec(&[tool1, tool2]); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Duplicate tool name")); + + Ok(()) + } + + #[test] + fn test_tools_to_openai_spec_empty() -> anyhow::Result<()> { + let spec = tools_to_openai_spec(&[])?; + assert!(spec.is_empty()); + Ok(()) + } + + #[test] + fn test_openai_response_to_message_text() -> anyhow::Result<()> { + let response = json!({ + "choices": [{ + "role": "assistant", + "message": { + "content": "Hello from 
John Cena!" + } + }], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35 + } + }); + + let message = openai_response_to_message(response)?; + assert_eq!(message.content.len(), 1); + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello from John Cena!"); + } else { + panic!("Expected Text content"); + } + assert!(matches!(message.role, Role::Assistant)); + + Ok(()) + } + + #[test] + fn test_openai_response_to_message_valid_toolrequest() -> anyhow::Result<()> { + let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?; + let message = openai_response_to_message(response)?; + + assert_eq!(message.content.len(), 1); + if let MessageContent::ToolRequest(request) = &message.content[0] { + let tool_call = request.tool_call.as_ref()?; + assert_eq!(tool_call.name, "example_fn"); + assert_eq!(tool_call.arguments, json!({"param": "value"})); + } else { + panic!("Expected ToolRequest content"); + } + + Ok(()) + } + + #[test] + fn test_openai_response_to_message_invalid_func_name() -> anyhow::Result<()> { + let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?; + response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] = + json!("invalid fn"); + + let message = openai_response_to_message(response)?; + + if let MessageContent::ToolRequest(request) = &message.content[0] { + match &request.tool_call { + Err(AgentError::ToolNotFound(msg)) => { + assert!(msg.starts_with("The provided function name")); + } + _ => panic!("Expected ToolNotFound error"), + } + } else { + panic!("Expected ToolRequest content"); + } + + Ok(()) + } + + #[test] + fn test_openai_response_to_message_json_decode_error() -> anyhow::Result<()> { + let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?; + response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] = + json!("invalid json {"); + + let message = openai_response_to_message(response)?; + + if let MessageContent::ToolRequest(request) = &message.content[0] { + match &request.tool_call { + Err(AgentError::InvalidParameters(msg)) => { + assert!(msg.starts_with("Could not interpret tool use parameters")); + } + _ => panic!("Expected InvalidParameters error"), + } + } else { + panic!("Expected ToolRequest content"); + } + + Ok(()) + } + + #[test] + fn test_check_openai_context_length_error() { + let error = json!({ + "code": "context_length_exceeded", + "message": "This message is too long" + }); + + let result = check_openai_context_length_error(&error); + assert!(result.is_some()); + assert_eq!( + result.unwrap().to_string(), + "Context length exceeded. 
Message: This message is too long" + ); + + let error = json!({ + "code": "other_error", + "message": "Some other error" + }); + + let result = check_openai_context_length_error(&error); + assert!(result.is_none()); + } +} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 6bb46fd12580..ee52a6059d17 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -4,13 +4,7 @@ use reqwest::{Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; -use crate::errors::AgentError; -use crate::message::{Message, MessageContent}; -use crate::providers::base::Usage; -use crate::providers::configs::ModelConfig; -use mcp_core::content::{Content, ImageContent}; -use mcp_core::role::Role; -use mcp_core::tool::{Tool, ToolCall}; +use mcp_core::content::ImageContent; #[derive(Debug, Copy, Clone, Serialize, Deserialize)] pub enum ImageFormat { @@ -18,129 +12,6 @@ pub enum ImageFormat { Anthropic, } -/// Convert internal Message format to OpenAI's API message specification -/// some openai compatible endpoints use the anthropic image spec at the content level -/// even though the message structure is otherwise following openai, the enum switches this -pub fn messages_to_openai_spec(messages: &[Message], image_format: &ImageFormat) -> Vec { - let mut messages_spec = Vec::new(); - for message in messages { - let mut converted = json!({ - "role": message.role - }); - - let mut output = Vec::new(); - - for content in &message.content { - match content { - MessageContent::Text(text) => { - if !text.text.is_empty() { - converted["content"] = json!(text.text); - } - } - MessageContent::ToolRequest(request) => match &request.tool_call { - Ok(tool_call) => { - let sanitized_name = sanitize_function_name(&tool_call.name); - let tool_calls = converted - .as_object_mut() - .unwrap() - .entry("tool_calls") - .or_insert(json!([])); - - tool_calls.as_array_mut().unwrap().push(json!({ - "id": request.id, - "type": "function", - "function": { - "name": sanitized_name, - "arguments": tool_call.arguments.to_string(), - } - })); - } - Err(e) => { - output.push(json!({ - "role": "tool", - "content": format!("Error: {}", e), - "tool_call_id": request.id - })); - } - }, - MessageContent::ToolResponse(response) => { - match &response.tool_result { - Ok(contents) => { - // Send only contents with no audience or with Assistant in the audience - let abridged: Vec<_> = contents - .iter() - .filter(|content| { - content - .audience() - .is_none_or(|audience| audience.contains(&Role::Assistant)) - }) - .map(|content| content.unannotated()) - .collect(); - - // Process all content, replacing images with placeholder text - let mut tool_content = Vec::new(); - let mut image_messages = Vec::new(); - - for content in abridged { - match content { - Content::Image(image) => { - // Add placeholder text in the tool response - tool_content.push(Content::text("This tool result included an image that is uploaded in the next message.")); - - // Create a separate image message - image_messages.push(json!({ - "role": "user", - "content": [convert_image(&image, image_format)] - })); - } - _ => { - tool_content.push(content); - } - } - } - let concatenated_content = tool_content - .iter() - .map(|content| match content { - Content::Text(text) => text.text.clone(), - _ => String::new(), - }) - .collect::>() - .join(" "); - // First add the tool response with all content - output.push(json!({ - "role": "tool", - "content": concatenated_content, 
- "tool_call_id": response.id - })); - // Then add any image messages that need to follow - output.extend(image_messages); - } - Err(e) => { - // A tool result error is shown as output so the model can interpret the error message - output.push(json!({ - "role": "tool", - "content": format!("The tool call returned the following error:\n{}", e), - "tool_call_id": response.id - })); - } - } - } - MessageContent::Image(image) => { - // Handle direct image content - converted["content"] = json!([convert_image(image, image_format)]); - } - } - } - - if converted.get("content").is_some() || converted.get("tool_calls").is_some() { - output.insert(0, converted); - } - messages_spec.extend(output); - } - - messages_spec -} - /// Convert an image content into an image json based on format pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value { match image_format { @@ -161,87 +32,6 @@ pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value } } -/// Convert internal Tool format to OpenAI's API tool specification -pub fn tools_to_openai_spec(tools: &[Tool]) -> Result> { - let mut tool_names = std::collections::HashSet::new(); - let mut result = Vec::new(); - - for tool in tools { - if !tool_names.insert(&tool.name) { - return Err(anyhow!("Duplicate tool name: {}", tool.name)); - } - - result.push(json!({ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.input_schema, - } - })); - } - - Ok(result) -} - -/// Convert OpenAI's API response to internal Message format -pub fn openai_response_to_message(response: Value) -> Result { - let original = response["choices"][0]["message"].clone(); - let mut content = Vec::new(); - - if let Some(text) = original.get("content") { - if let Some(text_str) = text.as_str() { - content.push(MessageContent::text(text_str)); - } - } - - if let Some(tool_calls) = original.get("tool_calls") { - if let Some(tool_calls_array) = tool_calls.as_array() { - for tool_call in tool_calls_array { - let id = tool_call["id"].as_str().unwrap_or_default().to_string(); - let function_name = tool_call["function"]["name"] - .as_str() - .unwrap_or_default() - .to_string(); - let arguments = tool_call["function"]["arguments"] - .as_str() - .unwrap_or_default() - .to_string(); - - if !is_valid_function_name(&function_name) { - let error = AgentError::ToolNotFound(format!( - "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", - function_name - )); - content.push(MessageContent::tool_request(id, Err(error))); - } else { - match serde_json::from_str::(&arguments) { - Ok(params) => { - content.push(MessageContent::tool_request( - id, - Ok(ToolCall::new(&function_name, params)), - )); - } - Err(e) => { - let error = AgentError::InvalidParameters(format!( - "Could not interpret tool use parameters for id {}: {}", - id, e - )); - content.push(MessageContent::tool_request(id, Err(error))); - } - } - } - } - } - } - - Ok(Message { - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content, - }) -} - pub async fn handle_response(payload: Value, response: Response) -> Result, Error> { Ok(match response.status() { StatusCode::OK => Ok(response.json().await?), @@ -257,76 +47,6 @@ pub async fn handle_response(payload: Value, response: Response) -> Result Result { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) 
- .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) -} - -pub fn create_openai_request_payload( - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], -) -> Result { - let system_message = json!({ - "role": "system", - "content": system - }); - - let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi); - let tools_spec = tools_to_openai_spec(tools)?; - - let mut messages_array = vec![system_message]; - messages_array.extend(messages_spec); - - let mut payload = json!({ - "model": model_config.model_name, - "messages": messages_array - }); - - if !tools_spec.is_empty() { - payload - .as_object_mut() - .unwrap() - .insert("tools".to_string(), json!(tools_spec)); - } - if let Some(temp) = model_config.temperature { - payload - .as_object_mut() - .unwrap() - .insert("temperature".to_string(), json!(temp)); - } - if let Some(tokens) = model_config.max_tokens { - payload - .as_object_mut() - .unwrap() - .insert("max_tokens".to_string(), json!(tokens)); - } - Ok(payload) -} - pub fn sanitize_function_name(name: &str) -> String { let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap(); re.replace_all(name, "_").to_string() @@ -339,21 +59,7 @@ pub fn is_valid_function_name(name: &str) -> bool { #[derive(Debug, thiserror::Error)] #[error("Context length exceeded. Message: {0}")] -pub struct ContextLengthExceededError(String); - -pub fn check_openai_context_length_error(error: &Value) -> Option { - let code = error.get("code")?.as_str()?; - if code == "context_length_exceeded" || code == "string_above_max_length" { - let message = error - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("Unknown error") - .to_string(); - Some(ContextLengthExceededError(message)) - } else { - None - } -} +pub struct ContextLengthExceededError(pub String); pub fn check_bedrock_context_length_error(error: &Value) -> Option { let external_message = error @@ -412,65 +118,8 @@ pub fn unescape_json_values(value: &Value) -> Value { #[cfg(test)] mod tests { use super::*; - use mcp_core::content::Content; use serde_json::json; - const OPENAI_TOOL_USE_RESPONSE: &str = r#"{ - "choices": [{ - "role": "assistant", - "message": { - "tool_calls": [{ - "id": "1", - "function": { - "name": "example_fn", - "arguments": "{\"param\": \"value\"}" - } - }] - } - }], - "usage": { - "input_tokens": 10, - "output_tokens": 25, - "total_tokens": 35 - } - }"#; - - #[test] - fn test_messages_to_openai_spec() -> Result<()> { - let message = Message::user().with_text("Hello"); - let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi); - - assert_eq!(spec.len(), 1); - assert_eq!(spec[0]["role"], "user"); - assert_eq!(spec[0]["content"], "Hello"); - Ok(()) - } - - #[test] - fn test_tools_to_openai_spec() -> Result<()> { - let tool = Tool::new( - "test_tool", - "A test tool", - json!({ - "type": "object", - "properties": { - "input": { - "type": "string", - "description": "Test parameter" - } - }, - "required": ["input"] - }), - ); - - let spec = tools_to_openai_spec(&[tool])?; - - assert_eq!(spec.len(), 1); - assert_eq!(spec[0]["type"], "function"); - assert_eq!(spec[0]["function"]["name"], "test_tool"); - Ok(()) - } - 
#[test] fn test_sanitize_function_name() { assert_eq!(sanitize_function_name("hello-world"), "hello-world"); @@ -486,207 +135,6 @@ mod tests { assert!(!is_valid_function_name("hello@world")); } - #[test] - fn test_messages_to_openai_spec_complex() -> Result<()> { - let mut messages = vec![ - Message::assistant().with_text("Hello!"), - Message::user().with_text("How are you?"), - Message::assistant().with_tool_request( - "tool1", - Ok(ToolCall::new("example", json!({"param1": "value1"}))), - ), - ]; - - // Get the ID from the tool request to use in the response - let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] { - request.id.clone() - } else { - panic!("should be tool request"); - }; - - messages - .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); - - let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi); - - assert_eq!(spec.len(), 4); - assert_eq!(spec[0]["role"], "assistant"); - assert_eq!(spec[0]["content"], "Hello!"); - assert_eq!(spec[1]["role"], "user"); - assert_eq!(spec[1]["content"], "How are you?"); - assert_eq!(spec[2]["role"], "assistant"); - assert!(spec[2]["tool_calls"].is_array()); - assert_eq!(spec[3]["role"], "tool"); - assert_eq!( - spec[3]["content"], - json!([{"text": "Result", "type": "text"}]) - ); - assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]); - - Ok(()) - } - - #[test] - fn test_tools_to_openai_spec_duplicate() -> Result<()> { - let tool1 = Tool::new( - "test_tool", - "Test tool", - json!({ - "type": "object", - "properties": { - "input": { - "type": "string", - "description": "Test parameter" - } - }, - "required": ["input"] - }), - ); - - let tool2 = Tool::new( - "test_tool", - "Test tool", - json!({ - "type": "object", - "properties": { - "input": { - "type": "string", - "description": "Test parameter" - } - }, - "required": ["input"] - }), - ); - - let result = tools_to_openai_spec(&[tool1, tool2]); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Duplicate tool name")); - - Ok(()) - } - - #[test] - fn test_tools_to_openai_spec_empty() -> Result<()> { - let spec = tools_to_openai_spec(&[])?; - assert!(spec.is_empty()); - Ok(()) - } - - #[test] - fn test_openai_response_to_message_text() -> Result<()> { - let response = json!({ - "choices": [{ - "role": "assistant", - "message": { - "content": "Hello from John Cena!" 
- } - }], - "usage": { - "input_tokens": 10, - "output_tokens": 25, - "total_tokens": 35 - } - }); - - let message = openai_response_to_message(response)?; - assert_eq!(message.content.len(), 1); - if let MessageContent::Text(text) = &message.content[0] { - assert_eq!(text.text, "Hello from John Cena!"); - } else { - panic!("Expected Text content"); - } - assert!(matches!(message.role, Role::Assistant)); - - Ok(()) - } - - #[test] - fn test_openai_response_to_message_valid_toolrequest() -> Result<()> { - let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?; - let message = openai_response_to_message(response)?; - - assert_eq!(message.content.len(), 1); - if let MessageContent::ToolRequest(request) = &message.content[0] { - let tool_call = request.tool_call.as_ref()?; - assert_eq!(tool_call.name, "example_fn"); - assert_eq!(tool_call.arguments, json!({"param": "value"})); - } else { - panic!("Expected ToolRequest content"); - } - - Ok(()) - } - - #[test] - fn test_openai_response_to_message_invalid_func_name() -> Result<()> { - let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?; - response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] = - json!("invalid fn"); - - let message = openai_response_to_message(response)?; - - if let MessageContent::ToolRequest(request) = &message.content[0] { - match &request.tool_call { - Err(AgentError::ToolNotFound(msg)) => { - assert!(msg.starts_with("The provided function name")); - } - _ => panic!("Expected ToolNotFound error"), - } - } else { - panic!("Expected ToolRequest content"); - } - - Ok(()) - } - - #[test] - fn test_openai_response_to_message_json_decode_error() -> Result<()> { - let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?; - response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] = - json!("invalid json {"); - - let message = openai_response_to_message(response)?; - - if let MessageContent::ToolRequest(request) = &message.content[0] { - match &request.tool_call { - Err(AgentError::InvalidParameters(msg)) => { - assert!(msg.starts_with("Could not interpret tool use parameters")); - } - _ => panic!("Expected InvalidParameters error"), - } - } else { - panic!("Expected ToolRequest content"); - } - - Ok(()) - } - - #[test] - fn test_check_openai_context_length_error() { - let error = json!({ - "code": "context_length_exceeded", - "message": "This message is too long" - }); - - let result = check_openai_context_length_error(&error); - assert!(result.is_some()); - assert_eq!( - result.unwrap().to_string(), - "Context length exceeded. 
Message: This message is too long" - ); - - let error = json!({ - "code": "other_error", - "message": "Some other error" - }); - - let result = check_openai_context_length_error(&error); - assert!(result.is_none()); - } - #[test] fn test_check_bedrock_context_length_error() { let error = json!({ From 8ce3bdd2f659185020e639f208342d6a99f8e410 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 17:38:53 +1100 Subject: [PATCH 07/22] fixed tests --- crates/goose/src/providers.rs | 2 +- crates/goose/src/providers/databricks.rs | 9 +++++---- crates/goose/src/providers/groq.rs | 4 +++- crates/goose/src/providers/ollama.rs | 6 ++++-- crates/goose/src/providers/openai.rs | 7 +++++-- crates/goose/src/providers/openai_utils.rs | 18 +++++++++--------- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 9dc54e46d6fb..e7d739142712 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -13,4 +13,4 @@ pub mod utils; pub mod google; pub mod groq; #[cfg(test)] -pub mod mock; \ No newline at end of file +pub mod mock; diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 6b0ad1e17ea0..1491b6b4e129 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -8,12 +8,13 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig}; use super::model_pricing::{cost, model_pricing_for}; use super::oauth; -use super::utils::{ - check_bedrock_context_length_error, get_model -}; +use super::utils::{check_bedrock_context_length_error, get_model}; use crate::message::Message; +use crate::providers::openai_utils::{ + check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, + tools_to_openai_spec, +}; use mcp_core::tool::Tool; -use crate::providers::openai_utils::{check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec}; pub struct DatabricksProvider { client: Client, diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 4519363593d2..eceb7310c977 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,13 +1,15 @@ use crate::message::Message; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; +use crate::providers::openai_utils::{ + create_openai_request_payload, get_openai_usage, openai_response_to_message, +}; use crate::providers::utils::{get_model, handle_response}; use async_trait::async_trait; use mcp_core::Tool; use reqwest::Client; use serde_json::Value; use std::time::Duration; -use crate::providers::openai_utils::{create_openai_request_payload, get_openai_usage, openai_response_to_message}; pub const GROQ_API_HOST: &str = "https://api.groq.com"; pub const GROQ_DEFAULT_MODEL: &str = "llama3-70b-8192"; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index d845eb7c666b..833fc236d9b5 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,14 +1,16 @@ use super::base::{Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; -use super::utils::{ get_model, handle_response, }; +use super::utils::{get_model, handle_response}; use 
crate::message::Message; +use crate::providers::openai_utils::{ + create_openai_request_payload, get_openai_usage, openai_response_to_message, +}; use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; use reqwest::Client; use serde_json::Value; use std::time::Duration; -use crate::providers::openai_utils::{create_openai_request_payload, get_openai_usage, openai_response_to_message}; pub const OLLAMA_HOST: &str = "http://localhost:11434"; pub const OLLAMA_MODEL: &str = "qwen2.5"; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 355a49b88390..b1beedce3091 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -12,8 +12,11 @@ use super::model_pricing::cost; use super::model_pricing::model_pricing_for; use super::utils::{get_model, handle_response}; use crate::message::Message; +use crate::providers::openai_utils::{ + check_openai_context_length_error, create_openai_request_payload, get_openai_usage, + openai_response_to_message, +}; use mcp_core::tool::Tool; -use crate::providers::openai_utils::{check_openai_context_length_error, create_openai_request_payload, get_openai_usage, openai_response_to_message}; pub struct OpenAiProvider { client: Client, @@ -221,7 +224,7 @@ mod tests { // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { - let tool_call = tool_request.tool_call.as_ref()?; + let tool_call = tool_request.tool_call.as_ref().unwrap(); assert_eq!(tool_call.name, "get_weather"); assert_eq!( tool_call.arguments, diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs index 18cfd0347cd7..2793ed07927b 100644 --- a/crates/goose/src/providers/openai_utils.rs +++ b/crates/goose/src/providers/openai_utils.rs @@ -1,11 +1,14 @@ -use anyhow::{anyhow, Error}; -use serde_json::{json, Value}; -use mcp_core::{Content, Role, Tool, ToolCall}; use crate::errors::AgentError; use crate::message::{Message, MessageContent}; use crate::providers::base::Usage; use crate::providers::configs::ModelConfig; -use crate::providers::utils::{convert_image, is_valid_function_name, sanitize_function_name, ContextLengthExceededError, ImageFormat}; +use crate::providers::utils::{ + convert_image, is_valid_function_name, sanitize_function_name, ContextLengthExceededError, + ImageFormat, +}; +use anyhow::{anyhow, Error}; +use mcp_core::{Content, Role, Tool, ToolCall}; +use serde_json::{json, Value}; /// Convert internal Message format to OpenAI's API message specification /// some openai compatible endpoints use the anthropic image spec at the content level @@ -388,10 +391,7 @@ mod tests { assert_eq!(spec[2]["role"], "assistant"); assert!(spec[2]["tool_calls"].is_array()); assert_eq!(spec[3]["role"], "tool"); - assert_eq!( - spec[3]["content"], - json!([{"text": "Result", "type": "text"}]) - ); + assert_eq!(spec[3]["content"], "Result"); assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]); Ok(()) @@ -481,7 +481,7 @@ mod tests { assert_eq!(message.content.len(), 1); if let MessageContent::ToolRequest(request) = &message.content[0] { - let tool_call = request.tool_call.as_ref()?; + let tool_call = request.tool_call.as_ref().unwrap(); assert_eq!(tool_call.name, "example_fn"); assert_eq!(tool_call.arguments, json!({"param": "value"})); } else { From 726471d15dd44c3bee412a4f8ebf189381d2edbb Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 17:50:56 +1100 Subject: [PATCH 08/22] fix format --- 
crates/goose-server/src/configuration.rs | 2 +- crates/goose-server/src/state.rs | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index 684474072421..b3c74dffa02d 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -7,7 +7,7 @@ use goose::providers::{ OpenAiProviderConfig, ProviderConfig, }, factory::ProviderType, - google, ollama, groq, + google, groq, ollama, utils::ImageFormat, }; use serde::Deserialize; diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 8973822b7078..8c07f82547bd 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use goose::providers::configs::GroqProviderConfig; use goose::{ agent::Agent, developer::DeveloperSystem, @@ -7,7 +8,6 @@ use goose::{ }; use std::sync::Arc; use tokio::sync::Mutex; -use goose::providers::configs::GroqProviderConfig; /// Shared application state pub struct AppState { @@ -72,13 +72,11 @@ impl Clone for AppState { model: config.model.clone(), }) } - ProviderConfig::Groq(config) => { - ProviderConfig::Groq(GroqProviderConfig { - host: config.host.clone(), - api_key: config.api_key.clone(), - model: config.model.clone(), - }) - } + ProviderConfig::Groq(config) => ProviderConfig::Groq(GroqProviderConfig { + host: config.host.clone(), + api_key: config.api_key.clone(), + model: config.model.clone(), + }), }, agent: self.agent.clone(), secret_key: self.secret_key.clone(), From 38ecbdc45716793924a5122a5fa5ed7f8c816bb5 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 20:15:43 +1100 Subject: [PATCH 09/22] added llama tokenizer --- crates/goose-cli/src/commands/configure.rs | 6 ++++-- crates/goose/build.rs | 1 + crates/goose/src/providers/groq.rs | 2 +- crates/goose/src/token_counter.rs | 5 ++++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index da079d043656..c9636c7a3b4d 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -6,6 +6,8 @@ use console::style; use goose::key_manager::{get_keyring_secret, save_to_keyring, KeyRetrievalStrategy}; use goose::message::Message; use goose::providers::factory; +use goose::providers::google::GOOGLE_DEFAULT_MODEL; +use goose::providers::groq::GROQ_DEFAULT_MODEL; use goose::providers::ollama::OLLAMA_MODEL; use std::error::Error; @@ -159,8 +161,8 @@ pub fn get_recommended_model(provider_name: &str) -> &str { "databricks" => "claude-3-5-sonnet-2", "ollama" => OLLAMA_MODEL, "anthropic" => "claude-3-5-sonnet-2", - "google" => "gemini-1.5-flash", - "groq" => "llama3-70b-8192", + "google" => GOOGLE_DEFAULT_MODEL, + "groq" => GROQ_DEFAULT_MODEL, _ => panic!("Invalid provider name"), } } diff --git a/crates/goose/build.rs b/crates/goose/build.rs index 421ad3df5821..ccfb369848bb 100644 --- a/crates/goose/build.rs +++ b/crates/goose/build.rs @@ -8,6 +8,7 @@ const MODELS: &[&str] = &[ "Xenova/gemma-2-tokenizer", "Xenova/gpt-4o", "Qwen/Qwen2.5-Coder-32B-Instruct", + "Xenova/llama3-tokenizer", ]; #[tokio::main] diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index eceb7310c977..a550b44284c6 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -12,7 +12,7 @@ use serde_json::Value; use std::time::Duration; pub 
const GROQ_API_HOST: &str = "https://api.groq.com"; -pub const GROQ_DEFAULT_MODEL: &str = "llama3-70b-8192"; +pub const GROQ_DEFAULT_MODEL: &str = "llama-3.3-70b-versatile"; pub struct GroqProvider { client: Client, diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 0a6e2dac779e..11546c106fa3 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -15,6 +15,7 @@ const GPT_4O_TOKENIZER_KEY: &str = "Xenova--gpt-4o"; const CLAUDE_TOKENIZER_KEY: &str = "Xenova--claude-tokenizer"; const GOOGLE_TOKENIZER_KEY: &str = "Xenova--gemma-2-tokenizer"; const QWEN_TOKENIZER_KEY: &str = "Qwen--Qwen2.5-Coder-32B-Instruct"; +const LLAMA_TOKENIZER_KEY: &str = "Xenova--llama3-tokenizer"; impl Default for TokenCounter { fn default() -> Self { @@ -53,6 +54,7 @@ impl TokenCounter { GPT_4O_TOKENIZER_KEY, CLAUDE_TOKENIZER_KEY, GOOGLE_TOKENIZER_KEY, + LLAMA_TOKENIZER_KEY, ] { counter.load_tokenizer(tokenizer_key); } @@ -65,13 +67,14 @@ impl TokenCounter { fn model_to_tokenizer_key(model_name: Option<&str>) -> &str { let model_name = model_name.unwrap_or("gpt-4o").to_lowercase(); - // Lifei: TODO: add llamas to the list if model_name.contains("claude") { CLAUDE_TOKENIZER_KEY } else if model_name.contains("qwen") { QWEN_TOKENIZER_KEY } else if model_name.contains("gemini") { GOOGLE_TOKENIZER_KEY + } else if model_name.contains("llama") { + LLAMA_TOKENIZER_KEY } else { // default GPT_4O_TOKENIZER_KEY From d9d184d054e17704dee8d1f3e8a35f63473356b6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 21:08:04 +1100 Subject: [PATCH 10/22] added missing qwen tokenizer --- crates/goose/src/token_counter.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 11546c106fa3..345b0acc68bd 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -54,6 +54,7 @@ impl TokenCounter { GPT_4O_TOKENIZER_KEY, CLAUDE_TOKENIZER_KEY, GOOGLE_TOKENIZER_KEY, + QWEN_TOKENIZER_KEY, LLAMA_TOKENIZER_KEY, ] { counter.load_tokenizer(tokenizer_key); From c025802e083e4e86c3cc90098de00df4b45a4296 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Dec 2024 11:06:31 +1100 Subject: [PATCH 11/22] clean up --- crates/goose-cli/src/commands/configure.rs | 9 ++++++--- crates/goose-server/src/configuration.rs | 3 ++- crates/goose/src/providers/anthropic.rs | 2 ++ crates/goose/src/providers/databricks.rs | 2 ++ crates/goose/src/providers/factory.rs | 5 ++--- crates/goose/src/providers/openai.rs | 2 ++ 6 files changed, 16 insertions(+), 7 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index c9636c7a3b4d..99298f433f43 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -5,10 +5,13 @@ use cliclack::spinner; use console::style; use goose::key_manager::{get_keyring_secret, save_to_keyring, KeyRetrievalStrategy}; use goose::message::Message; +use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL; +use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::providers::factory; use goose::providers::google::GOOGLE_DEFAULT_MODEL; use goose::providers::groq::GROQ_DEFAULT_MODEL; use goose::providers::ollama::OLLAMA_MODEL; +use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; use std::error::Error; pub async fn handle_configure( @@ -157,10 +160,10 @@ pub async fn handle_configure( pub fn get_recommended_model(provider_name: &str) -> &str { match 
provider_name { - "openai" => "gpt-4o", - "databricks" => "claude-3-5-sonnet-2", + "openai" => OPEN_AI_DEFAULT_MODEL, + "databricks" => DATABRICKS_DEFAULT_MODEL, "ollama" => OLLAMA_MODEL, - "anthropic" => "claude-3-5-sonnet-2", + "anthropic" => ANTHROPIC_DEFAULT_MODEL, "google" => GOOGLE_DEFAULT_MODEL, "groq" => GROQ_DEFAULT_MODEL, _ => panic!("Invalid provider name"), diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index b3c74dffa02d..c6435db591f2 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -1,6 +1,7 @@ use crate::error::{to_env_var, ConfigError}; use config::{Config, Environment}; use goose::providers::configs::{GoogleProviderConfig, GroqProviderConfig}; +use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; use goose::providers::{ configs::{ DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig, @@ -265,7 +266,7 @@ fn default_port() -> u16 { } fn default_model() -> String { - "gpt-4o".to_string() + OPEN_AI_DEFAULT_MODEL.to_string() } fn default_openai_host() -> String { diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index c769eef4f4e5..c23d77df7e7f 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -17,6 +17,8 @@ use mcp_core::content::Content; use mcp_core::role::Role; use mcp_core::tool::{Tool, ToolCall}; +pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest"; + pub struct AnthropicProvider { client: Client, config: AnthropicProviderConfig, diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 1491b6b4e129..b918eefe665b 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -16,6 +16,8 @@ use crate::providers::openai_utils::{ }; use mcp_core::tool::Tool; +pub const DATABRICKS_DEFAULT_MODEL: &str = "claude-3-5-sonnet-2"; + pub struct DatabricksProvider { client: Client, config: DatabricksProviderConfig, diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index eaf23fc16b8b..58ad7513bef2 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,9 +1,8 @@ use super::{ anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig, - databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider, - openai::OpenAiProvider, + databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider, + ollama::OllamaProvider, openai::OpenAiProvider, }; -use crate::providers::groq::GroqProvider; use anyhow::Result; use strum_macros::EnumIter; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index b1beedce3091..ba3470221725 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -18,6 +18,8 @@ use crate::providers::openai_utils::{ }; use mcp_core::tool::Tool; +pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o"; + pub struct OpenAiProvider { client: Client, config: OpenAiProviderConfig, From e3485bf676b409ca4bfe325b0ae6512312d59c9a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Dec 2024 11:51:59 +1100 Subject: [PATCH 12/22] refactored openai integration tests --- crates/goose/src/providers.rs | 2 + crates/goose/src/providers/mock.rs | 15 +-- crates/goose/src/providers/mock_server.rs | 94 +++++++++++++++++++ crates/goose/src/providers/openai.rs | 109 
++++++---------------- 4 files changed, 132 insertions(+), 88 deletions(-) create mode 100644 crates/goose/src/providers/mock_server.rs diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index e7d739142712..08e29029e2e7 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -14,3 +14,5 @@ pub mod google; pub mod groq; #[cfg(test)] pub mod mock; +#[cfg(test)] +mod mock_server; diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index 830c20601a82..11505d9d8fad 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -1,15 +1,16 @@ +use super::base::ProviderUsage; +use crate::message::Message; +use crate::providers::base::{Provider, Usage}; +use crate::providers::configs::ModelConfig; +use crate::providers::openai::OpenAiProvider; use anyhow::Result; use async_trait::async_trait; +use mcp_core::tool::Tool; use rust_decimal_macros::dec; +use serde_json::Value; use std::sync::Arc; use std::sync::Mutex; - -use crate::message::Message; -use crate::providers::base::{Provider, Usage}; -use crate::providers::configs::ModelConfig; -use mcp_core::tool::Tool; - -use super::base::ProviderUsage; +use wiremock::MockServer; /// A mock provider that returns pre-configured responses for testing pub struct MockProvider { diff --git a/crates/goose/src/providers/mock_server.rs b/crates/goose/src/providers/mock_server.rs new file mode 100644 index 000000000000..fbc4305b7af0 --- /dev/null +++ b/crates/goose/src/providers/mock_server.rs @@ -0,0 +1,94 @@ +use mcp_core::Tool; +use serde_json::{json, Value}; +use wiremock::matchers::{method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +pub const TEST_INPUT_TOKENS: i32 = 12; +pub const TEST_OUTPUT_TOKENS: i32 = 15; +pub const TEST_TOTAL_TOKENS: i32 = 27; +pub const TEST_TOOL_FUNCTION_NAME: &str = "get_weather"; +pub const TEST_TOOL_FUNCTION_ARGUMENTS: &str = "{\"location\":\"San Francisco, CA\"}"; + +pub async fn setup_mock_server(path_url: &str, response_body: Value) -> MockServer { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path(path_url)) + .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) + .mount(&mock_server) + .await; + mock_server +} + +pub fn create_mock_open_ai_response_with_tools(model_name: &str) -> Value { + json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": TEST_TOOL_FUNCTION_NAME, + "arguments": TEST_TOOL_FUNCTION_ARGUMENTS + } + }] + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": TEST_INPUT_TOKENS, + "completion_tokens": TEST_OUTPUT_TOKENS, + "total_tokens": TEST_TOTAL_TOKENS + }, + "model": model_name + }) +} + +pub fn create_mock_open_ai_response(content: &str, model_name: &str) -> Value { + json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": content, + "tool_calls": null + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": TEST_INPUT_TOKENS, + "completion_tokens": TEST_OUTPUT_TOKENS, + "total_tokens": TEST_TOTAL_TOKENS + }, + "model": model_name + }) +} + +pub fn create_test_tool() -> Tool { + Tool::new( + "get_weather", + "Gets the current weather for a location", + json!({ + "type": "object", + "properties": { + "location": { + "type": 
"string", + "description": "The city and state, e.g. New York, NY" + } + }, + "required": ["location"] + }), + ) +} + +pub fn get_expected_function_call_arguments() -> Value { + json!({ + "location": "San Francisco, CA" + }) +} diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index ba3470221725..630d5a8df295 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -97,18 +97,18 @@ mod tests { use super::*; use crate::message::MessageContent; use crate::providers::configs::ModelConfig; + use crate::providers::mock_server::{ + create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, + get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, + TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, + }; use rust_decimal_macros::dec; use serde_json::json; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; - async fn _setup_mock_server(response_body: Value) -> (MockServer, OpenAiProvider) { - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) - .mount(&mock_server) - .await; + async fn _setup_mock_response(response_body: Value) -> (MockServer, OpenAiProvider) { + let mock_server = setup_mock_server("/v1/chat/completions", response_body).await; // Create the OpenAiProvider with the mock server's URL as the host let config = OpenAiProviderConfig { @@ -123,28 +123,12 @@ mod tests { #[tokio::test] async fn test_complete_basic() -> Result<()> { + let model_name = "gpt-4o"; // Mock response for normal completion - let response_body = json!({ - "id": "chatcmpl-123", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I assist you today?", - "tool_calls": null - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 15, - "total_tokens": 27 - }, - "model": "gpt-4o" - }); - - let (_, provider) = _setup_mock_server(response_body).await; + let response_body = + create_mock_open_ai_response("Hello! 
How can I assist you today?", model_name); + + let (_, provider) = _setup_mock_response(response_body).await; // Prepare input messages let messages = vec![Message::user().with_text("Hello?")]; @@ -160,10 +144,10 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(usage.usage.input_tokens, Some(12)); - assert_eq!(usage.usage.output_tokens, Some(15)); - assert_eq!(usage.usage.total_tokens, Some(27)); - assert_eq!(usage.model, "gpt-4o"); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); assert_eq!(usage.cost, Some(dec!(0.00018))); Ok(()) @@ -172,73 +156,36 @@ mod tests { #[tokio::test] async fn test_complete_tool_request() -> Result<()> { // Mock response for tool calling - let response_body = json!({ - "id": "chatcmpl-tool", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": "call_123", - "type": "function", - "function": { - "name": "get_weather", - "arguments": "{\"location\":\"San Francisco, CA\"}" - } - }] - }, - "finish_reason": "tool_calls" - }], - "usage": { - "prompt_tokens": 20, - "completion_tokens": 15, - "total_tokens": 35 - } - }); + let response_body = create_mock_open_ai_response_with_tools("gpt-4o"); - let (_, provider) = _setup_mock_server(response_body).await; + let (_, provider) = _setup_mock_response(response_body).await; // Input messages let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; // Define the tool using builder pattern - let tool = Tool::new( - "get_weather", - "Gets the current weather for a location", - json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
New York, NY" - } - }, - "required": ["location"] - }), - ); // Call the complete method let (message, usage) = provider - .complete("You are a helpful assistant.", &messages, &[tool]) + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) .await?; // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { let tool_call = tool_request.tool_call.as_ref().unwrap(); - assert_eq!(tool_call.name, "get_weather"); - assert_eq!( - tool_call.arguments, - json!({"location": "San Francisco, CA"}) - ); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); } else { panic!("Expected ToolCall content"); } - assert_eq!(usage.usage.input_tokens, Some(20)); - assert_eq!(usage.usage.output_tokens, Some(15)); - assert_eq!(usage.usage.total_tokens, Some(35)); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); Ok(()) } From 152ea971633f2a9e6048301223a83e56ac4bda14 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Dec 2024 11:54:53 +1100 Subject: [PATCH 13/22] refactored databricks integration tests --- crates/goose/src/providers/databricks.rs | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index b918eefe665b..56d96580a9ca 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -182,6 +182,9 @@ mod tests { use super::*; use crate::message::MessageContent; use crate::providers::configs::ModelConfig; + use crate::providers::mock_server::{ + create_mock_open_ai_response, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOTAL_TOKENS, + }; use wiremock::matchers::{body_json, header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -191,19 +194,7 @@ mod tests { let mock_server = MockServer::start().await; // Mock response for completion - let mock_response = json!({ - "choices": [{ - "message": { - "role": "assistant", - "content": "Hello!" 
- } - }], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 25, - "total_tokens": 35 - } - }); + let mock_response = create_mock_open_ai_response("Hello!", "my-databricks-model"); // Expected request body let system = "You are a helpful assistant."; @@ -247,9 +238,9 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(reply_usage.usage.input_tokens, Some(10)); - assert_eq!(reply_usage.usage.output_tokens, Some(25)); - assert_eq!(reply_usage.usage.total_tokens, Some(35)); + assert_eq!(reply_usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(reply_usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(reply_usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); Ok(()) } From aa6b67860e2670052f405ee20b460acb4775c266 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Dec 2024 12:04:19 +1100 Subject: [PATCH 14/22] refactored ollma integration tests --- crates/goose/src/providers/databricks.rs | 2 +- crates/goose/src/providers/mock.rs | 3 - crates/goose/src/providers/mock_server.rs | 12 ++- crates/goose/src/providers/ollama.rs | 113 ++++++---------------- crates/goose/src/providers/openai.rs | 6 +- 5 files changed, 42 insertions(+), 94 deletions(-) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 56d96580a9ca..a5d66678b129 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -194,7 +194,7 @@ mod tests { let mock_server = MockServer::start().await; // Mock response for completion - let mock_response = create_mock_open_ai_response("Hello!", "my-databricks-model"); + let mock_response = create_mock_open_ai_response("my-databricks-model", "Hello!"); // Expected request body let system = "You are a helpful assistant."; diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index 11505d9d8fad..270870b59dcd 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -2,15 +2,12 @@ use super::base::ProviderUsage; use crate::message::Message; use crate::providers::base::{Provider, Usage}; use crate::providers::configs::ModelConfig; -use crate::providers::openai::OpenAiProvider; use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; use rust_decimal_macros::dec; -use serde_json::Value; use std::sync::Arc; use std::sync::Mutex; -use wiremock::MockServer; /// A mock provider that returns pre-configured responses for testing pub struct MockProvider { diff --git a/crates/goose/src/providers/mock_server.rs b/crates/goose/src/providers/mock_server.rs index fbc4305b7af0..0350742d79c2 100644 --- a/crates/goose/src/providers/mock_server.rs +++ b/crates/goose/src/providers/mock_server.rs @@ -1,3 +1,4 @@ +use axum::http::StatusCode; use mcp_core::Tool; use serde_json::{json, Value}; use wiremock::matchers::{method, path}; @@ -19,6 +20,15 @@ pub async fn setup_mock_server(path_url: &str, response_body: Value) -> MockServ mock_server } +pub async fn setup_mock_server_with_response_code(path_url: &str, response_code: u16) -> MockServer { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path(path_url)) + .respond_with(ResponseTemplate::new(response_code)) + .mount(&mock_server) + .await; + mock_server +} pub fn create_mock_open_ai_response_with_tools(model_name: &str) -> Value { json!({ "id": "chatcmpl-123", @@ -48,7 +58,7 @@ pub fn create_mock_open_ai_response_with_tools(model_name: &str) -> Value { }) } -pub fn 
create_mock_open_ai_response(content: &str, model_name: &str) -> Value { +pub fn create_mock_open_ai_response(model_name: &str, content: &str) -> Value { json!({ "id": "chatcmpl-123", "object": "chat.completion", diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 833fc236d9b5..b5a4af49f1e2 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -75,18 +75,12 @@ impl Provider for OllamaProvider { mod tests { use super::*; use crate::message::MessageContent; - use serde_json::json; - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; + use crate::providers::mock_server::{create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, get_expected_function_call_arguments, setup_mock_server, setup_mock_server_with_response_code, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS}; + use rust_decimal_macros::dec; + use wiremock::MockServer; async fn _setup_mock_server(response_body: Value) -> (MockServer, OllamaProvider) { - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) - .mount(&mock_server) - .await; - + let mock_server = setup_mock_server("/v1/chat/completions", response_body).await; // Create the OllamaProvider with the mock server's URL as the host let config = OllamaProviderConfig { host: mock_server.uri(), @@ -99,25 +93,10 @@ mod tests { #[tokio::test] async fn test_complete_basic() -> Result<()> { + let model_name = "gpt-4o"; // Mock response for normal completion - let response_body = json!({ - "id": "chatcmpl-123", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! How can I assist you today?", - "tool_calls": null - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 15, - "total_tokens": 27 - } - }); + let response_body = + create_mock_open_ai_response(model_name, "Hello! 
How can I assist you today?"); let (_, provider) = _setup_mock_server(response_body).await; @@ -135,9 +114,11 @@ mod tests { } else { panic!("Expected Text content"); } - assert_eq!(usage.usage.input_tokens, Some(12)); - assert_eq!(usage.usage.output_tokens, Some(15)); - assert_eq!(usage.usage.total_tokens, Some(27)); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); + assert_eq!(usage.cost, Some(dec!(0.00018))); Ok(()) } @@ -145,81 +126,43 @@ mod tests { #[tokio::test] async fn test_complete_tool_request() -> Result<()> { // Mock response for tool calling - let response_body = json!({ - "id": "chatcmpl-tool", - "object": "chat.completion", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": "call_h5d3s25w", - "type": "function", - "function": { - "name": "read_file", - "arguments": "{\"filename\":\"test.txt\"}" - } - }] - }, - "finish_reason": "tool_calls" - }], - "usage": { - "prompt_tokens": 63, - "completion_tokens": 70, - "total_tokens": 133 - } - }); + let response_body = create_mock_open_ai_response_with_tools("gpt-4o"); let (_, provider) = _setup_mock_server(response_body).await; // Input messages - let messages = vec![Message::user().with_text("Can you read the test.txt file?")]; - - // Define the tool - let tool = Tool::new( - "read_file", - "Read the content of a file", - json!({ - "type": "object", - "properties": { - "filename": { - "type": "string", - "description": "The name of the file to read" - } - }, - "required": ["filename"] - }), - ); + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + + // Define the tool using builder pattern // Call the complete method let (message, usage) = provider - .complete("You are a helpful assistant.", &messages, &[tool]) + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) .await?; + // Assert the response if let MessageContent::ToolRequest(tool_request) = &message.content[0] { let tool_call = tool_request.tool_call.as_ref().unwrap(); - assert_eq!(tool_call.name, "read_file"); - assert_eq!(tool_call.arguments, json!({"filename": "test.txt"})); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); } else { panic!("Expected ToolCall content"); } - assert_eq!(usage.usage.input_tokens, Some(63)); - assert_eq!(usage.usage.output_tokens, Some(70)); - assert_eq!(usage.usage.total_tokens, Some(133)); + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); Ok(()) } #[tokio::test] async fn test_server_error() -> Result<()> { - let mock_server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with(ResponseTemplate::new(500)) - .mount(&mock_server) - .await; + let mock_server = setup_mock_server_with_response_code("/v1/chat/completions", 500).await; let config = OllamaProviderConfig { host: mock_server.uri(), diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 630d5a8df295..54884b040dec 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -103,9 +103,7 @@ mod tests { 
TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, }; use rust_decimal_macros::dec; - use serde_json::json; - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::MockServer; async fn _setup_mock_response(response_body: Value) -> (MockServer, OpenAiProvider) { let mock_server = setup_mock_server("/v1/chat/completions", response_body).await; @@ -126,7 +124,7 @@ mod tests { let model_name = "gpt-4o"; // Mock response for normal completion let response_body = - create_mock_open_ai_response("Hello! How can I assist you today?", model_name); + create_mock_open_ai_response(model_name, "Hello! How can I assist you today?"); let (_, provider) = _setup_mock_response(response_body).await; From d66e48a019338e98062c4b9ad273bb629161e3f7 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Dec 2024 14:19:14 +1100 Subject: [PATCH 15/22] added groq integration test --- crates/goose/src/providers/groq.rs | 86 +++++++++++++++++++++++ crates/goose/src/providers/mock_server.rs | 1 - crates/goose/src/providers/ollama.rs | 2 - 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index a550b44284c6..71e5b86a00d1 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -72,3 +72,89 @@ impl Provider for GroqProvider { Ok((message, ProviderUsage::new(model, usage, None))) } } + +mod tests { + use super::*; + use crate::message::MessageContent; + use crate::providers::mock_server::{create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS}; + use wiremock::MockServer; + + async fn _setup_mock_server(response_body: Value) -> (MockServer, GroqProvider) { + let mock_server = setup_mock_server("/openai/v1/chat/completions", response_body).await; + let config = GroqProviderConfig { + host: mock_server.uri(), + api_key: "test_api_key".to_string(), + model: ModelConfig::new(GROQ_DEFAULT_MODEL.to_string()), + }; + + let provider = GroqProvider::new(config).unwrap(); + (mock_server, provider) + } + + #[tokio::test] + async fn test_complete_basic() -> anyhow::Result<()> { + let model_name = "gpt-4o"; + // Mock response for normal completion + let response_body = + create_mock_open_ai_response(model_name, "Hello! How can I assist you today?"); + + let (_, provider) = _setup_mock_server(response_body).await; + + // Prepare input messages + let messages = vec![Message::user().with_text("Hello?")]; + + // Call the complete method + let (message, usage) = provider + .complete("You are a helpful assistant.", &messages, &[]) + .await?; + + // Assert the response + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello! 
How can I assist you today?"); + } else { + panic!("Expected Text content"); + } + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); + assert_eq!(usage.cost, None); + + Ok(()) + } + + #[tokio::test] + async fn test_complete_tool_request() -> anyhow::Result<()> { + // Mock response for tool calling + let response_body = create_mock_open_ai_response_with_tools("gpt-4o"); + + let (_, provider) = _setup_mock_server(response_body).await; + + // Input messages + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + + // Call the complete method + let (message, usage) = provider + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) + .await?; + + // Assert the response + if let MessageContent::ToolRequest(tool_request) = &message.content[0] { + let tool_call = tool_request.tool_call.as_ref().unwrap(); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); + } else { + panic!("Expected ToolCall content"); + } + + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + + Ok(()) + } +} diff --git a/crates/goose/src/providers/mock_server.rs b/crates/goose/src/providers/mock_server.rs index 0350742d79c2..4c24801e1043 100644 --- a/crates/goose/src/providers/mock_server.rs +++ b/crates/goose/src/providers/mock_server.rs @@ -1,4 +1,3 @@ -use axum::http::StatusCode; use mcp_core::Tool; use serde_json::{json, Value}; use wiremock::matchers::{method, path}; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index b5a4af49f1e2..b9c4439027ee 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -133,8 +133,6 @@ mod tests { // Input messages let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; - // Define the tool using builder pattern - // Call the complete method let (message, usage) = provider .complete( From 13d896796e871b350530c5681b4122f886a9c2f5 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Dec 2024 14:55:37 +1100 Subject: [PATCH 16/22] added google provider integration tests --- crates/goose/src/providers.rs | 2 +- crates/goose/src/providers/google.rs | 84 +++++++++++++++++++++++ crates/goose/src/providers/groq.rs | 7 +- crates/goose/src/providers/mock_server.rs | 51 +++++++++++++- crates/goose/src/providers/ollama.rs | 7 +- 5 files changed, 147 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 08e29029e2e7..6f2fb5b9152f 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -15,4 +15,4 @@ pub mod groq; #[cfg(test)] pub mod mock; #[cfg(test)] -mod mock_server; +pub mod mock_server; diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 3ed83b2722f9..6c6ec433c6c7 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -349,6 +349,9 @@ impl Provider for GoogleProvider { mod tests { use super::*; use crate::errors::AgentResult; + use crate::providers::mock_server::{create_mock_google_ai_response, create_mock_google_ai_response_with_tools, create_test_tool, 
From 13d896796e871b350530c5681b4122f886a9c2f5 Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Thu, 19 Dec 2024 14:55:37 +1100
Subject: [PATCH 16/22] added google provider integration tests

---
 crates/goose/src/providers.rs             |  2 +-
 crates/goose/src/providers/google.rs      | 84 +++++++++++++++++++++++
 crates/goose/src/providers/groq.rs        |  7 +-
 crates/goose/src/providers/mock_server.rs | 51 +++++++++++++-
 crates/goose/src/providers/ollama.rs      |  7 +-
 5 files changed, 147 insertions(+), 4 deletions(-)

diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs
index 08e29029e2e7..6f2fb5b9152f 100644
--- a/crates/goose/src/providers.rs
+++ b/crates/goose/src/providers.rs
@@ -15,4 +15,4 @@ pub mod groq;
 #[cfg(test)]
 pub mod mock;
 #[cfg(test)]
-mod mock_server;
+pub mod mock_server;
diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs
index 3ed83b2722f9..6c6ec433c6c7 100644
--- a/crates/goose/src/providers/google.rs
+++ b/crates/goose/src/providers/google.rs
@@ -349,6 +349,9 @@ impl Provider for GoogleProvider {
 mod tests {
     use super::*;
     use crate::errors::AgentResult;
+    use crate::providers::mock_server::{create_mock_google_ai_response, create_mock_google_ai_response_with_tools, create_test_tool, get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS};
+    use wiremock::MockServer;
+
     fn set_up_provider() -> GoogleProvider {
         let provider_config = GoogleProviderConfig {
             host: "dummy_host".to_string(),
@@ -605,4 +608,85 @@ mod tests {
             panic!("Expected valid tool request");
         }
     }
+
+    async fn _setup_mock_server(model_name: &str, response_body: Value) -> (MockServer, GoogleProvider) {
+        let path_url = format!("/v1beta/models/{}:generateContent", model_name);
+        let mock_server = setup_mock_server(&path_url, response_body).await;
+        let config = GoogleProviderConfig {
+            host: mock_server.uri(),
+            api_key: "test_api_key".to_string(),
+            model: ModelConfig::new(GOOGLE_DEFAULT_MODEL.to_string()),
+        };
+
+        let provider = GoogleProvider::new(config).unwrap();
+        (mock_server, provider)
+    }
+
+    #[tokio::test]
+    async fn test_complete_basic() -> anyhow::Result<()> {
+        let model_name = "gemini-1.5-flash";
+        // Mock response for normal completion
+        let response_body =
+            create_mock_google_ai_response(model_name, "Hello! How can I assist you today?");
+
+        let (_, provider) = _setup_mock_server(model_name, response_body).await;
+
+        // Prepare input messages
+        let messages = vec![Message::user().with_text("Hello?")];
+
+        // Call the complete method
+        let (message, usage) = provider
+            .complete("You are a helpful assistant.", &messages, &[])
+            .await?;
+
+        // Assert the response
+        if let MessageContent::Text(text) = &message.content[0] {
+            assert_eq!(text.text, "Hello! How can I assist you today?");
+        } else {
+            panic!("Expected Text content");
+        }
+        assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS));
+        assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS));
+        assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS));
+        assert_eq!(usage.model, model_name);
+        assert_eq!(usage.cost, None);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_complete_tool_request() -> anyhow::Result<()> {
+        let model_name = "gemini-1.5-flash";
+        // Mock response for tool calling
+        let response_body = create_mock_google_ai_response_with_tools(model_name);
+
+        let (_, provider) = _setup_mock_server(model_name, response_body).await;
+
+        // Input messages
+        let messages = vec![Message::user().with_text("What's the weather in San Francisco?")];
+
+        // Call the complete method
+        let (message, usage) = provider
+            .complete(
+                "You are a helpful assistant.",
+                &messages,
+                &[create_test_tool()],
+            )
+            .await?;
+
+        // Assert the response
+        if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
+            let tool_call = tool_request.tool_call.as_ref().unwrap();
+            assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME);
+            assert_eq!(tool_call.arguments, get_expected_function_call_arguments());
+        } else {
+            panic!("Expected ToolCall content");
+        }
+
+        assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS));
+        assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS));
+        assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS));
+
+        Ok(())
+    }
 }
diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs
index 71e5b86a00d1..070a866e9471 100644
--- a/crates/goose/src/providers/groq.rs
+++ b/crates/goose/src/providers/groq.rs
@@ -73,10 +73,15 @@ impl Provider for GroqProvider {
     }
 }
 
+#[cfg(test)]
 mod tests {
     use super::*;
     use crate::message::MessageContent;
-    use crate::providers::mock_server::{create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS};
+    use crate::providers::mock_server::{
+        create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool,
+        get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS,
+        TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS,
+    };
     use wiremock::MockServer;
 
     async fn _setup_mock_server(response_body: Value) -> (MockServer, GroqProvider) {
diff --git a/crates/goose/src/providers/mock_server.rs b/crates/goose/src/providers/mock_server.rs
index 4c24801e1043..8712cb8635c7 100644
--- a/crates/goose/src/providers/mock_server.rs
+++ b/crates/goose/src/providers/mock_server.rs
@@ -19,7 +19,10 @@ pub async fn setup_mock_server(path_url: &str, response_body: Value) -> MockServer {
     mock_server
 }
 
-pub async fn setup_mock_server_with_response_code(path_url: &str, response_code: u16) -> MockServer {
+pub async fn setup_mock_server_with_response_code(
+    path_url: &str,
+    response_code: u16,
+) -> MockServer {
     let mock_server = MockServer::start().await;
     Mock::given(method("POST"))
         .and(path(path_url))
@@ -57,6 +60,52 @@ pub fn create_mock_open_ai_response_with_tools(model_name: &str) -> Value {
     })
 }
 
+pub fn create_mock_google_ai_response_with_tools(model_name: &str) -> Value {
+    json!({
+        "candidates": [{
+            "content": {
+                "parts": [{
+                    "functionCall": {
+                        "name": TEST_TOOL_FUNCTION_NAME,
+                        "args": {
+                            "location": "San Francisco, CA"
+                        }
+                    }
+                }],
+                "role": "model"
+            },
+            "finishReason": "STOP"
+        }],
+        "modelVersion": model_name,
+        "usageMetadata": {
+            "candidatesTokenCount": TEST_OUTPUT_TOKENS,
+            "promptTokenCount": TEST_INPUT_TOKENS,
+            "totalTokenCount": TEST_TOTAL_TOKENS
+        }
+    })
+}
+
+pub fn create_mock_google_ai_response(model_name: &str, content: &str) -> Value {
+    json!({
+        "candidates": [{
+            "content": {
+                "parts": [{
+                    "text": content
+                }],
+                "role": "model"
+            },
+            "finishReason": "STOP"
+        }],
+        "modelVersion": model_name,
+        "usageMetadata": {
+            "candidatesTokenCount": TEST_OUTPUT_TOKENS,
+            "promptTokenCount": TEST_INPUT_TOKENS,
+            "totalTokenCount": TEST_TOTAL_TOKENS
+        }
+    })
+}
+
 pub fn create_mock_open_ai_response(model_name: &str, content: &str) -> Value {
     json!({
         "id": "chatcmpl-123",
diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index b9c4439027ee..cb9bb1449242 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -75,7 +75,12 @@ impl Provider for OllamaProvider {
 mod tests {
     use super::*;
     use crate::message::MessageContent;
-    use crate::providers::mock_server::{create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, get_expected_function_call_arguments, setup_mock_server, setup_mock_server_with_response_code, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS};
+    use crate::providers::mock_server::{
+        create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool,
+        get_expected_function_call_arguments, setup_mock_server,
+        setup_mock_server_with_response_code, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS,
+        TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS,
+    };
     use rust_decimal_macros::dec;
     use wiremock::MockServer;
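The Google mocks added above pin down the generateContent response shape this series relies on: `candidates[].content.parts[]` holds either a `text` part or a `functionCall` part, and token counts live under `usageMetadata`. Purely as an illustration of reading that shape (this is not the extraction code in google.rs):

    use serde_json::Value;

    /// Pull (input, output, total) token counts out of a Gemini-style reply.
    fn usage_from_google(data: &Value) -> (Option<i32>, Option<i32>, Option<i32>) {
        let meta = &data["usageMetadata"];
        let get = |key: &str| meta.get(key).and_then(Value::as_i64).map(|v| v as i32);
        (
            get("promptTokenCount"),
            get("candidatesTokenCount"),
            get("totalTokenCount"),
        )
    }

Because the mocks embed the shared TEST_*_TOKENS constants, the tests can assert the same expected counts regardless of provider.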
From 5cd0c061323f1ef9aea28e20c8f3d83abef238dd Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Thu, 19 Dec 2024 15:01:37 +1100
Subject: [PATCH 17/22] used util functions

---
 crates/goose/src/providers/databricks.rs | 44 +++--------------------
 crates/goose/src/providers/google.rs     | 11 ++++--
 2 files changed, 14 insertions(+), 41 deletions(-)

diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs
index a5d66678b129..da2445f1dc0c 100644
--- a/crates/goose/src/providers/databricks.rs
+++ b/crates/goose/src/providers/databricks.rs
@@ -8,11 +8,11 @@ use super::base::{Provider, ProviderUsage, Usage};
 use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig};
 use super::model_pricing::{cost, model_pricing_for};
 use super::oauth;
-use super::utils::{check_bedrock_context_length_error, get_model};
+use super::utils::{check_bedrock_context_length_error, get_model, handle_response};
 use crate::message::Message;
 use crate::providers::openai_utils::{
-    check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message,
-    tools_to_openai_spec,
+    check_openai_context_length_error, get_openai_usage, messages_to_openai_spec,
+    openai_response_to_message, tools_to_openai_spec,
 };
 use mcp_core::tool::Tool;
 
@@ -49,30 +49,7 @@ impl DatabricksProvider {
     }
 
     fn get_usage(data: &Value) -> Result<Usage> {
-        let usage = data
-            .get("usage")
-            .ok_or_else(|| anyhow!("No usage data in response"))?;
-
-        let input_tokens = usage
-            .get("prompt_tokens")
-            .and_then(|v| v.as_i64())
-            .map(|v| v as i32);
-
-        let output_tokens = usage
-            .get("completion_tokens")
-            .and_then(|v| v.as_i64())
-            .map(|v| v as i32);
-
-        let total_tokens = usage
-            .get("total_tokens")
-            .and_then(|v| v.as_i64())
-            .map(|v| v as i32)
-            .or_else(|| match (input_tokens, output_tokens) {
-                (Some(input), Some(output)) => Some(input + output),
-                _ => None,
-            });
-
-        Ok(Usage::new(input_tokens, output_tokens, total_tokens))
+        get_openai_usage(data)
     }
 
     async fn post(&self, payload: Value) -> Result<Value> {
@@ -91,18 +68,7 @@ impl DatabricksProvider {
             .send()
             .await?;
 
-        match response.status() {
-            StatusCode::OK => Ok(response.json().await?),
-            status if status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() => {
-                // Implement retry logic here if needed
-                Err(anyhow!("Server error: {}", status))
-            }
-            _ => {
-                let status = response.status();
-                let err_text = response.text().await.unwrap_or_default();
-                Err(anyhow!("Request failed: {}: {}", status, err_text))
-            }
-        }
+        handle_response(payload, response).await?
    }
}
diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs
index 6c6ec433c6c7..d96b681add87 100644
--- a/crates/goose/src/providers/google.rs
+++ b/crates/goose/src/providers/google.rs
@@ -349,7 +349,11 @@ impl Provider for GoogleProvider {
 mod tests {
     use super::*;
     use crate::errors::AgentResult;
-    use crate::providers::mock_server::{create_mock_google_ai_response, create_mock_google_ai_response_with_tools, create_test_tool, get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS};
+    use crate::providers::mock_server::{
+        create_mock_google_ai_response, create_mock_google_ai_response_with_tools,
+        create_test_tool, get_expected_function_call_arguments, setup_mock_server,
+        TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS,
+    };
     use wiremock::MockServer;
 
     fn set_up_provider() -> GoogleProvider {
@@ -609,7 +613,10 @@ mod tests {
         }
     }
 
-    async fn _setup_mock_server(model_name: &str, response_body: Value) -> (MockServer, GoogleProvider) {
+    async fn _setup_mock_server(
+        model_name: &str,
+        response_body: Value,
+    ) -> (MockServer, GoogleProvider) {
         let path_url = format!("/v1beta/models/{}:generateContent", model_name);
         let mock_server = setup_mock_server(&path_url, response_body).await;
         let config = GoogleProviderConfig {
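Two utilities absorb the code deleted above. `get_openai_usage` presumably carries over the exact body removed from databricks.rs; a sketch under that assumption (`Usage` is the provider-layer struct imported from base.rs elsewhere in these files):

    use anyhow::{anyhow, Result};
    use serde_json::Value;
    use super::base::Usage;

    pub fn get_openai_usage(data: &Value) -> Result<Usage> {
        // OpenAI-compatible replies put token counts under "usage".
        let usage = data
            .get("usage")
            .ok_or_else(|| anyhow!("No usage data in response"))?;

        let count = |key: &str| usage.get(key).and_then(|v| v.as_i64()).map(|v| v as i32);

        let input_tokens = count("prompt_tokens");
        let output_tokens = count("completion_tokens");
        // Fall back to summing the parts when "total_tokens" is absent.
        let total_tokens = count("total_tokens").or_else(|| match (input_tokens, output_tokens) {
            (Some(input), Some(output)) => Some(input + output),
            _ => None,
        });

        Ok(Usage::new(input_tokens, output_tokens, total_tokens))
    }

`handle_response` likewise centralizes the status handling that post() used to do inline (200 becomes parsed JSON, 429 and 5xx become errors). Judging by the trailing `?` at the call site, its return type nests a Result, but that signature is not shown in this series.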
From ab5279efcbb2b3695474f8f5f4a72085c763c171 Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Fri, 20 Dec 2024 06:54:45 +1100
Subject: [PATCH 18/22] only concat tool response content for ollama and groq

---
 crates/goose/src/providers/databricks.rs   |  2 +-
 crates/goose/src/providers/openai_utils.rs | 56 +++++++++++++++-------
 2 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs
index da2445f1dc0c..fa5d83c82ff8 100644
--- a/crates/goose/src/providers/databricks.rs
+++ b/crates/goose/src/providers/databricks.rs
@@ -81,7 +81,7 @@ impl Provider for DatabricksProvider {
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage)> {
         // Prepare messages and tools
-        let messages_spec = messages_to_openai_spec(messages, &self.config.image_format);
+        let messages_spec = messages_to_openai_spec(messages, &self.config.image_format, false);
 
         let tools_spec = if !tools.is_empty() {
             tools_to_openai_spec(tools)?
         } else {
diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs
index 2793ed07927b..15a08505f147 100644
--- a/crates/goose/src/providers/openai_utils.rs
+++ b/crates/goose/src/providers/openai_utils.rs
@@ -13,7 +13,11 @@
 use serde_json::{json, Value};
 
 /// Convert internal Message format to OpenAI's API message specification
 /// some openai compatible endpoints use the anthropic image spec at the content level
 /// even though the message structure is otherwise following openai, the enum switches this
-pub fn messages_to_openai_spec(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> {
+pub fn messages_to_openai_spec(
+    messages: &[Message],
+    image_format: &ImageFormat,
+    concat_tool_response_contents: bool,
+) -> Vec<Value> {
     let mut messages_spec = Vec::new();
     for message in messages {
         let mut converted = json!({
@@ -90,20 +94,31 @@ pub fn messages_to_openai_spec(
                     }
                 }
             }
-            let concatenated_content = tool_content
-                .iter()
-                .map(|content| match content {
-                    Content::Text(text) => text.text.clone(),
-                    _ => String::new(),
-                })
-                .collect::<Vec<String>>()
-                .join(" ");
-            // First add the tool response with all content
-            output.push(json!({
-                "role": "tool",
-                "content": concatenated_content,
-                "tool_call_id": response.id
-            }));
+            match concat_tool_response_contents {
+                true => {
+                    let concatenated_content = tool_content
+                        .iter()
+                        .map(|content| match content {
+                            Content::Text(text) => text.text.clone(),
+                            _ => String::new(),
+                        })
+                        .collect::<Vec<String>>()
+                        .join(" ");
+                    // First add the tool response with all content
+                    output.push(json!({
+                        "role": "tool",
+                        "content": concatenated_content,
+                        "tool_call_id": response.id
+                    }));
+                }
+                false => {
+                    output.push(json!({
+                        "role": "tool",
+                        "content": tool_content,
+                        "tool_call_id": response.id
+                    }));
+                }
+            };
             // Then add any image messages that need to follow
             output.extend(image_messages);
         }
@@ -246,13 +261,18 @@ pub fn create_openai_request_payload(
     system: &str,
     messages: &[Message],
     tools: &[Tool],
+    concat_tool_response_contents: bool,
 ) -> anyhow::Result<Value> {
     let system_message = json!({
         "role": "system",
        "content": system
     });
 
-    let messages_spec = messages_to_openai_spec(messages, &ImageFormat::OpenAi);
+    let messages_spec = messages_to_openai_spec(
+        messages,
+        &ImageFormat::OpenAi,
+        concat_tool_response_contents,
+    );
     let tools_spec = tools_to_openai_spec(tools)?;
 
     let mut messages_array = vec![system_message];
@@ -327,7 +347,7 @@ mod tests {
     #[test]
     fn test_messages_to_openai_spec() -> anyhow::Result<()> {
         let message = Message::user().with_text("Hello");
-        let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi);
+        let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi, false);
 
         assert_eq!(spec.len(), 1);
         assert_eq!(spec[0]["role"], "user");
@@ -381,7 +401,7 @@ mod tests {
         messages
             .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
 
-        let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi);
+        let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, false);
 
         assert_eq!(spec.len(), 4);
         assert_eq!(spec[0]["role"], "assistant");
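What the new flag actually toggles is the shape of the `tool` role message sent back to the endpoint; the call sites that choose a mode per provider land in the next patch. Illustratively (the structured serialization of `mcp_core::Content` is an assumption, inferred from the assertions in the tests further down):

    use serde_json::json;

    // concat_tool_response_contents == true: text parts squashed into one string,
    // for endpoints that only accept plain string content on tool messages
    let concatenated = json!({"role": "tool", "content": "Result", "tool_call_id": "tool1"});

    // concat_tool_response_contents == false: the content list passed through as-is
    let structured = json!({"role": "tool", "content": [{"type": "text", "text": "Result"}], "tool_call_id": "tool1"});
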
From 7f4a90f83059bfddaf8dbc908b52886bbc5a6281 Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Fri, 20 Dec 2024 07:49:56 +1100
Subject: [PATCH 19/22] only concat tool response content for ollama and groq

---
 crates/goose/src/providers/databricks.rs   |  2 +-
 crates/goose/src/providers/groq.rs         |  3 ++-
 crates/goose/src/providers/ollama.rs       |  3 ++-
 crates/goose/src/providers/openai.rs       |  3 ++-
 crates/goose/src/providers/openai_utils.rs | 27 +++++++++-------------
 5 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs
index fa5d83c82ff8..6a2670392ea1 100644
--- a/crates/goose/src/providers/databricks.rs
+++ b/crates/goose/src/providers/databricks.rs
@@ -1,6 +1,6 @@
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
-use reqwest::{Client, StatusCode};
+use reqwest::Client;
 use serde_json::{json, Value};
 use std::time::Duration;
 
diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs
index 070a866e9471..b0d1be1dc2a7 100644
--- a/crates/goose/src/providers/groq.rs
+++ b/crates/goose/src/providers/groq.rs
@@ -61,7 +61,8 @@ impl Provider for GroqProvider {
         messages: &[Message],
         tools: &[Tool],
     ) -> anyhow::Result<(Message, ProviderUsage)> {
-        let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?;
+        let payload =
+            create_openai_request_payload(&self.config.model, system, messages, tools, true)?;
 
         let response = self.post(payload).await?;
 
diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index cb9bb1449242..961e100b7b5f 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -57,7 +57,8 @@ impl Provider for OllamaProvider {
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage)> {
-        let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?;
+        let payload =
+            create_openai_request_payload(&self.config.model, system, messages, tools, false)?;
 
         let response = self.post(payload).await?;
 
diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs
index 54884b040dec..f735b2209c64 100644
--- a/crates/goose/src/providers/openai.rs
+++ b/crates/goose/src/providers/openai.rs
@@ -69,7 +69,8 @@ impl Provider for OpenAiProvider {
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage)> {
         // Not checking for o1 model here since system message is not supported by o1
-        let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?;
+        let payload =
+            create_openai_request_payload(&self.config.model, system, messages, tools, false)?;
 
         // Make request
         let response = self.post(payload).await?;
 
diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs
index 15a08505f147..d8b4a34ef5ea 100644
--- a/crates/goose/src/providers/openai_utils.rs
+++ b/crates/goose/src/providers/openai_utils.rs
@@ -94,31 +94,26 @@ pub fn messages_to_openai_spec(
                 }
             }
 
-            match concat_tool_response_contents {
+            let tool_response_content: Value = match concat_tool_response_contents {
                 true => {
-                    let concatenated_content = tool_content
+                    json!(tool_content
                         .iter()
                         .map(|content| match content {
                             Content::Text(text) => text.text.clone(),
                             _ => String::new(),
                         })
                         .collect::<Vec<String>>()
-                        .join(" ");
-                    // First add the tool response with all content
-                    output.push(json!({
-                        "role": "tool",
-                        "content": concatenated_content,
-                        "tool_call_id": response.id
-                    }));
-                }
-                false => {
-                    output.push(json!({
-                        "role": "tool",
-                        "content": tool_content,
-                        "tool_call_id": response.id
-                    }));
+                        .join(" "))
                 }
+                false => json!(tool_content),
             };
+
+            // First add the tool response with all content
+            output.push(json!({
+                "role": "tool",
+                "content": tool_response_content,
+                "tool_call_id": response.id
+            }));
             // Then add any image messages that need to follow
             output.extend(image_messages);
         }
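Per the call sites in this patch, only groq ends up passing `true`; ollama and openai keep the structured form despite the commit subject. A small style note on the refactor: matching on a bool compiles fine, but clippy's `match_bool` lint usually nudges toward if/else. An equivalent form of the same branch, behavior unchanged (same imports as the surrounding function):

    let tool_response_content: Value = if concat_tool_response_contents {
        // Squash all text parts into one space-separated string
        json!(tool_content
            .iter()
            .map(|content| match content {
                Content::Text(text) => text.text.clone(),
                _ => String::new(),
            })
            .collect::<Vec<String>>()
            .join(" "))
    } else {
        // Pass the structured content list through unchanged
        json!(tool_content)
    };
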
From 529f5d1a41615b494205e805ad11b59da2c651f6 Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Fri, 20 Dec 2024 07:58:55 +1100
Subject: [PATCH 20/22] fixed the test

---
 crates/goose/src/providers/openai_utils.rs | 31 +++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/crates/goose/src/providers/openai_utils.rs b/crates/goose/src/providers/openai_utils.rs
index d8b4a34ef5ea..0be8d8917448 100644
--- a/crates/goose/src/providers/openai_utils.rs
+++ b/crates/goose/src/providers/openai_utils.rs
@@ -396,7 +396,7 @@ mod tests {
         messages
             .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
 
-        let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, false);
+        let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, true);
 
         assert_eq!(spec.len(), 4);
         assert_eq!(spec[0]["role"], "assistant");
@@ -412,6 +412,35 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_messages_to_openai_spec_not_concat_tool_response_content() -> anyhow::Result<()> {
+        let mut messages = vec![Message::assistant().with_tool_request(
+            "tool1",
+            Ok(ToolCall::new("example", json!({"param1": "value1"}))),
+        )];
+
+        // Get the ID from the tool request to use in the response
+        let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
+            request.id.clone()
+        } else {
+            panic!("should be tool request");
+        };
+
+        messages
+            .push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
+
+        let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, false);
+
+        assert_eq!(spec.len(), 2);
+        assert_eq!(spec[0]["role"], "assistant");
+        assert!(spec[0]["tool_calls"].is_array());
+        assert_eq!(spec[1]["role"], "tool");
+        assert_eq!(spec[1]["content"][0]["text"], "Result");
+        assert_eq!(spec[1]["tool_call_id"], spec[0]["tool_calls"][0]["id"]);
+
+        Ok(())
+    }
+
     #[test]
     fn test_tools_to_openai_spec_duplicate() -> anyhow::Result<()> {
         let tool1 = Tool::new(
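The fix gives each serialization mode its own test: the pre-existing test asserts the squashed-string shape and so now opts into `true`, while the new test pins down the structured list produced by `false`. With both covered, the relevant suite can be run in isolation while iterating, using cargo's standard name filter:

    $ cargo test -p goose openai_utils
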
From 2b2dad7a6e945406d32a393930b6490d34337901 Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Fri, 20 Dec 2024 08:38:29 +1100
Subject: [PATCH 21/22] fixed the format

---
 crates/goose/src/providers/ollama.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index 961e100b7b5f..0efefc017ada 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -124,7 +124,7 @@ mod tests {
         assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS));
         assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS));
         assert_eq!(usage.model, model_name);
-        assert_eq!(usage.cost, Some(dec!(0.00018)));
+        assert_eq!(usage.cost, None);
 
         Ok(())
     }
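The substantive change here is the expected cost in the ollama test. A plausible reading, since the pricing helpers are only visible as imports in this series: the mocked model has no entry in the pricing table, so the computed cost comes back `None` rather than a `Decimal`. Roughly, with shapes assumed from `use super::model_pricing::{cost, model_pricing_for};`:

    // Assumed signatures; the real ones live in model_pricing.rs.
    let pricing = model_pricing_for(model_name);   // assumed Option<ModelPricing>; None for unknown models
    let usage_cost = cost(&usage, &pricing);       // assumed Option<Decimal>; None without pricing
    assert_eq!(usage_cost, None);
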
From 0674f308d425f18db5582df8b70a98d181b49e13 Mon Sep 17 00:00:00 2001
From: Lifei Zhou
Date: Fri, 20 Dec 2024 09:11:16 +1100
Subject: [PATCH 22/22] clean up

---
 crates/goose/src/providers/ollama.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index 0efefc017ada..540f9291f862 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -82,7 +82,6 @@ mod tests {
         setup_mock_server_with_response_code, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS,
         TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS,
     };
-    use rust_decimal_macros::dec;
     use wiremock::MockServer;
 
     async fn _setup_mock_server(response_body: Value) -> (MockServer, OllamaProvider) {
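To apply and check the whole series locally, assuming the patches are saved as numbered mbox files in the current directory:

    $ git am 0*.patch
    $ cargo test -p goose providers

The filter narrows the run to the suites this series touches (groq, google, ollama, openai_utils), all of which live under crates/goose/src/providers.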