From 971944306bdedcfc993c61878cc373b160d942f5 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 08:52:49 +1100 Subject: [PATCH 01/17] added google provider --- crates/goose-cli/src/commands/configure.rs | 3 + crates/goose-cli/src/profile.rs | 17 +- crates/goose-server/src/configuration.rs | 26 +++ crates/goose-server/src/state.rs | 9 + crates/goose/src/providers.rs | 1 + crates/goose/src/providers/configs.rs | 11 + crates/goose/src/providers/factory.rs | 5 + crates/goose/src/providers/google.rs | 253 +++++++++++++++++++++ crates/goose/src/providers/utils.rs | 4 +- crates/goose/src/token_counter.rs | 7 +- 10 files changed, 329 insertions(+), 7 deletions(-) create mode 100644 crates/goose/src/providers/google.rs diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 67a5bc1f7b8d..10934d278c87 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -47,6 +47,7 @@ pub async fn handle_configure( ("databricks", "Databricks", "Models on AI Gateway"), ("ollama", "Ollama", "Local open source models"), ("anthropic", "Anthropic", "Claude models"), + ("google", "Google Gemini", "Gemini models"), ]) .interact()? .to_string() @@ -153,6 +154,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str { "databricks" => "claude-3-5-sonnet-2", "ollama" => OLLAMA_MODEL, "anthropic" => "claude-3-5-sonnet-2", + "google" => "gemini-1.5-flash-8b", _ => panic!("Invalid provider name"), } } @@ -163,6 +165,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> { "databricks" => vec!["DATABRICKS_HOST"], "ollama" => vec!["OLLAMA_HOST"], "anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint + "google" => vec!["GOOGLE_API_KEY"], _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index c1bcb33e26c0..8bc4f521362e 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -1,9 +1,6 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; -use goose::providers::configs::{ - AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, - OpenAiProviderConfig, ProviderConfig, -}; +use goose::providers::configs::{AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; @@ -122,6 +119,18 @@ pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig max_tokens: None, }) } + "google" => { + let api_key = get_keyring_secret("GOOGLE_API_KEY", KeyRetrievalStrategy::Both) + .expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); + + ProviderConfig::Google(GoogleProviderConfig { + host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint + api_key, + model, + temperature: None, + max_tokens: None, + }) + } _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index c4e2d706c8fe..a1fb929be130 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -11,6 +11,7 @@ use goose::providers::{ }; use serde::Deserialize; use std::net::SocketAddr; +use goose::providers::configs::GoogleProviderConfig; #[derive(Debug, Default, Deserialize)] pub struct ServerSettings { @@ -64,6 +65,17 @@ pub enum ProviderSettings { #[serde(default)] max_tokens: Option, }, + Google { + #[serde(default = "default_google_host")] + host: String, + api_key: String, + #[serde(default = "default_google_model")] + model: String, + #[serde(default)] + temperature: Option, + #[serde(default)] + max_tokens: Option, + }, } impl ProviderSettings { @@ -74,6 +86,7 @@ impl ProviderSettings { ProviderSettings::OpenAi { .. } => ProviderType::OpenAi, ProviderSettings::Databricks { .. } => ProviderType::Databricks, ProviderSettings::Ollama { .. } => ProviderType::Ollama, + ProviderSettings::Google { .. } => ProviderType::Google, } } @@ -118,6 +131,19 @@ impl ProviderSettings { temperature, max_tokens, }), + ProviderSettings::Google{ + host, + api_key, + model, + temperature, + max_tokens, + } => ProviderConfig::Google(GoogleProviderConfig { + host, + api_key, + model, + temperature, + max_tokens, + }), } } } diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 10ad3e4e2502..3078bbc94d61 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -72,6 +72,15 @@ impl Clone for AppState { max_tokens: config.max_tokens, }) } + ProviderConfig::Google(config) => { + ProviderConfig::Google(goose::providers::configs::GoogleProviderConfig { + host: config.host.clone(), + api_key: config.api_key.clone(), + model: config.model.clone(), + temperature: config.temperature, + max_tokens: config.max_tokens, + }) + } }, agent: self.agent.clone(), secret_key: self.secret_key.clone(), diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 18834d14b7ae..1839a8c54d27 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -10,3 +10,4 @@ pub mod utils; #[cfg(test)] pub mod mock; +mod google; diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 91a827909eda..04c00b6d6141 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -11,6 +11,7 @@ pub enum ProviderConfig { Databricks(DatabricksProviderConfig), Ollama(OllamaProviderConfig), Anthropic(AnthropicProviderConfig), + Google(GoogleProviderConfig), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -81,6 +82,16 @@ pub struct OpenAiProviderConfig { pub max_tokens: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GoogleProviderConfig { + pub host: String, + pub api_key: String, + pub model: String, + pub temperature: Option, + pub max_tokens: Option, +} + + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaProviderConfig { pub host: String, diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 46f5b3ff8382..f34fe22dee54 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -4,6 +4,7 @@ use super::{ }; use anyhow::Result; use strum_macros::EnumIter; +use crate::providers::google::GoogleProvider; #[derive(EnumIter, Debug)] pub enum ProviderType { @@ -11,6 +12,7 @@ pub enum ProviderType { Databricks, Ollama, Anthropic, + Google, } pub fn get_provider(config: ProviderConfig) -> Result> { @@ -23,5 +25,8 @@ pub fn get_provider(config: ProviderConfig) -> Result { Ok(Box::new(AnthropicProvider::new(anthropic_config)?)) } + ProviderConfig::Google(google_config) => { + Ok(Box::new(GoogleProvider::new(google_config)?)) + } } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs new file mode 100644 index 000000000000..cf248bacac10 --- /dev/null +++ b/crates/goose/src/providers/google.rs @@ -0,0 +1,253 @@ +use std::time::Duration; +use anyhow::anyhow; +use async_trait::async_trait; +use reqwest::{Client, StatusCode}; +use serde_json::{json, Value, Map}; +use mcp_core::{Content, Role, Tool, ToolCall}; +use crate::errors::AgentError; +use crate::message::{Message, MessageContent}; +use crate::providers::base::{Provider, ProviderUsageCollector, Usage}; +use crate::providers::configs::GoogleProviderConfig; +use crate::providers::utils::{is_valid_function_name}; + +pub struct GoogleProvider { + client: Client, + config: GoogleProviderConfig, + usage_collector: ProviderUsageCollector, +} + +impl GoogleProvider { + pub fn new(config: GoogleProviderConfig) -> anyhow::Result { + let client = Client::builder() + .timeout(Duration::from_secs(600)) // 10 minutes timeout + .build()?; + + Ok(Self { + client, + config, + usage_collector: ProviderUsageCollector::new(), + }) + } + + fn get_usage(data: &Value) -> anyhow::Result { + let usage = data + .get("usage") + .ok_or_else(|| anyhow!("No usage data in response"))?; + + let input_tokens = usage + .get("prompt_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let output_tokens = usage + .get("completion_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + let total_tokens = usage + .get("total_tokens") + .and_then(|v| v.as_i64()) + .map(|v| v as i32) + .or_else(|| match (input_tokens, output_tokens) { + (Some(input), Some(output)) => Some(input + output), + _ => None, + }); + + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + } + + async fn post(&self, payload: Value) -> anyhow::Result { + let url = format!( + "{}/v1beta/models/{}:generateContent?key={}", + self.config.host.trim_end_matches('/'), + self.config.model, + self.config.api_key + ); + + let response = self + .client + .post(&url) + .header("CONTENT_TYPE", "application/json") + .json(&payload) + .send() + .await?; + + match response.status() { + StatusCode::OK => Ok(response.json().await?), + status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { + // Implement retry logic here if needed + Err(anyhow!("Server error: {}", status)) + } + _ => Err(anyhow!( + "Request failed: {}\nPayload: {}", + response.status(), + payload + )), + } + } +} + +#[async_trait] +impl Provider for GoogleProvider { + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> anyhow::Result<(Message, Usage)> { + // Lifei: TODO: temperature parameters, tools may be empty, images + let mut payload = Map::new(); + payload.insert("system_instruction".to_string(), json!({"parts": [{"text": system}]})); + payload.insert("contents".to_string(), json!(messages_to_google_spec(&messages))); + if !tools.is_empty() { + payload.insert("tools".to_string(), json!({"functionDeclarations": tools_to_google_spec(&tools)})); + } + + // Make request + let response = self.post(serde_json::Value::Object(payload)).await?; + + // Lifei: TODO handle api errors https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python + // // Raise specific error if context length is exceeded + // if let Some(error) = response.get("error") { + // if let Some(err) = check_openai_context_length_error(error) { + // return Err(err.into()); + // } + // return Err(anyhow!("OpenAI API error: {}", error)); + // } + + // Parse response + let message = google_response_to_message(response.clone())?; + let usage = Usage::new(Some(100), Some(100), Some(100)); + // let usage = Self::get_usage(&response)?; + // self.usage_collector.add_usage(usage.clone()); + + Ok((message, usage)) + } + + + + fn total_usage(&self) -> Usage { + self.usage_collector.get_usage() + } +} + +fn messages_to_google_spec(messages: &[Message]) -> Vec { + messages + .iter() + .map(|message| { + let role = if message.role == Role::User { "user" } else { "model" }; + let mut parts = Vec::new(); + for message_content in message.content.iter() { + match message_content { + MessageContent::Text(text) => { + if !text.text.is_empty() { + parts.push(json!({"text": text.text})); + } + } + MessageContent::ToolRequest(request) => match &request.tool_call { + Ok(tool_call) => { + parts.push(json!({ + "functionCall": { + "name": tool_call.name, + "arguments": tool_call.arguments, + } + })); + } + Err(e) => { + parts.push(json!({"text":format!("Error: {}", e)})); + } + } + MessageContent::ToolResponse(response) => { + match &response.tool_result { + Ok(contents) => { + // Send only contents with no audience or with Assistant in the audience + let abridged: Vec<_> = contents + .iter() + .filter(|content| { + content + .audience() + .is_none_or(|audience| audience.contains(&Role::Assistant)) + }) + .map(|content| content.unannotated()) + .collect(); + + for content in abridged { + match content { + Content::Image(image) => { + } + _ => { + parts.push(json!({ + "functionResponse": { + "name": response.id, + "response": {"content": content}, + }} + )); + } + } + } + } + Err(e) => { + parts.push(json!({"text":format!("Error: {}", e)})); + } + } + } + + _ => {} + } + } + json!({"role": role, "parts": parts}) + }) + .collect() +} + +fn tools_to_google_spec(tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| { + json!({ + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + }) + }) + .collect() +} + +fn google_response_to_message(response: Value) -> anyhow::Result { + let mut content = Vec::new(); + let binding = vec![]; + let candidates: &Vec = response.get("candidates").and_then(|v| v.as_array()).unwrap_or(&binding); + let candidate = candidates.get(0); + let role = Role::Assistant; + let created = chrono::Utc::now().timestamp(); + if candidate.is_none() { + return Ok(Message { role, created, content}); + } + let candidate = candidate.unwrap(); + let parts = candidate.get("content") + .and_then(|content| content.get("parts")).and_then(|parts| parts.as_array()).unwrap_or(&binding); + for part in parts { + if let Some(text) = part.get("text").and_then(|v| v.as_str()) { + content.push(MessageContent::text(text.to_string())); + } else if let Some(function_call) = part.get("functionCall") { + let id = function_call["name"].as_str().unwrap_or_default().to_string(); + let name = function_call["name"].as_str().unwrap_or_default().to_string(); + if !is_valid_function_name(&name) { + let error = AgentError::ToolNotFound(format!( + "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", + name + )); + content.push(MessageContent::tool_request(id, Err(error))); + } else { + let parameters = function_call.get("arguments"); + if parameters.is_some() { + content.push(MessageContent::tool_request( + id, + Ok(ToolCall::new(&name, parameters.unwrap().clone())))); + } + } + + } + } + Ok(Message { role, created, content}) +} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index b337603faa9c..46fd4b6a742b 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -234,12 +234,12 @@ pub fn openai_response_to_message(response: Value) -> Result { }) } -fn sanitize_function_name(name: &str) -> String { +pub fn sanitize_function_name(name: &str) -> String { let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap(); re.replace_all(name, "_").to_string() } -fn is_valid_function_name(name: &str) -> bool { +pub fn is_valid_function_name(name: &str) -> bool { let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap(); re.is_match(name) } diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 40d1ab09a0ed..c6f2e7e9a5f0 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -13,6 +13,7 @@ pub struct TokenCounter { const GPT_4O_TOKENIZER_KEY: &str = "Xenova--gpt-4o"; const CLAUDE_TOKENIZER_KEY: &str = "Xenova--claude-tokenizer"; +const GOOGLE_TOKENIZER_KEY: &str = "Xenova--google-tokenizer"; const QWEN_TOKENIZER_KEY: &str = "Qwen--Qwen2.5-Coder-32B-Instruct"; impl Default for TokenCounter { @@ -48,7 +49,7 @@ impl TokenCounter { tokenizers: HashMap::new(), }; // Add default tokenizers - for tokenizer_key in [GPT_4O_TOKENIZER_KEY, CLAUDE_TOKENIZER_KEY] { + for tokenizer_key in [GPT_4O_TOKENIZER_KEY, CLAUDE_TOKENIZER_KEY, GOOGLE_TOKENIZER_KEY] { counter.load_tokenizer(tokenizer_key); } counter @@ -60,10 +61,14 @@ impl TokenCounter { fn model_to_tokenizer_key(model_name: Option<&str>) -> &str { let model_name = model_name.unwrap_or("gpt-4o").to_lowercase(); + // Lifei: to remove + println!("Model name: {}", model_name); if model_name.contains("claude") { CLAUDE_TOKENIZER_KEY } else if model_name.contains("qwen") { QWEN_TOKENIZER_KEY + } else if model_name.contains("gemini") { + GOOGLE_TOKENIZER_KEY } else { // default GPT_4O_TOKENIZER_KEY From 0804910b413f6e7e9bd49f3dd9f5756b1e1975a5 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 13:39:29 +1100 Subject: [PATCH 02/17] cleaned up parameter payload --- crates/goose/src/providers/google.rs | 61 +++++++++++++++++++++++----- crates/goose/src/token_counter.rs | 3 +- download_tokenizers.sh | 3 +- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index cf248bacac10..45dfe84cf78f 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -104,7 +104,7 @@ impl Provider for GoogleProvider { } // Make request - let response = self.post(serde_json::Value::Object(payload)).await?; + let response = self.post(Value::Object(payload)).await?; // Lifei: TODO handle api errors https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python // // Raise specific error if context length is exceeded @@ -146,11 +146,13 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { } MessageContent::ToolRequest(request) => match &request.tool_call { Ok(tool_call) => { + let mut function_call_part = Map::new(); + function_call_part.insert("name".to_string(), json!(tool_call.name)); + if tool_call.arguments.is_object() && !tool_call.arguments.as_object().unwrap().is_empty() { + function_call_part.insert("arguments".to_string(), tool_call.arguments.clone()); + } parts.push(json!({ - "functionCall": { - "name": tool_call.name, - "arguments": tool_call.arguments, - } + "functionCall": function_call_part })); } Err(e) => { @@ -204,15 +206,54 @@ fn tools_to_google_spec(tools: &[Tool]) -> Vec { tools .iter() .map(|tool| { - json!({ - "name": tool.name, - "description": tool.description, - "parameters": tool.input_schema, - }) + let mut parameters = Map::new(); + parameters.insert("name".to_string(), json!(tool.name)); + parameters.insert("description".to_string(), json!(tool.description)); + let tool_input_schema = tool.input_schema.as_object().unwrap(); + let tool_input_schema_properties = tool_input_schema.get("properties").unwrap_or(&json!({})).as_object().unwrap().clone(); + if !tool_input_schema_properties.is_empty() { + let accepted_tool_schema_attributes = vec!["type".to_string(), "format".to_string(), "description".to_string(), "nullable".to_string(), "enum".to_string(), "maxItems".to_string(), "minItems".to_string(), "properties".to_string(), "required".to_string(), "items".to_string()]; + parameters.insert("parameters".to_string(), json!(process_map(tool_input_schema, &accepted_tool_schema_attributes, None))); + } + json!(parameters) }) .collect() } +fn process_map( + map: &Map, + accepted_keys: &[String], + parent_key: Option<&str>, // Track the parent key +) -> Value { + let mut filtered_map: Map = map + .iter() + .filter_map(|(key, value)| { + let should_remove = !accepted_keys.contains(key) && parent_key != Some("properties"); + if should_remove { + return None; + } + // Process nested maps recursively + let filtered_value = match value { + Value::Object(nested_map) => { + process_map( + &nested_map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(), + accepted_keys, + Some(key), + ) + } + _ => value.clone(), + }; + + Some((key.clone(), filtered_value)) + }) + .collect(); + if parent_key != Some("properties") && !filtered_map.contains_key("type") { + filtered_map.insert("type".to_string(), Value::String("string".to_string())); + } + + Value::Object(filtered_map) +} + fn google_response_to_message(response: Value) -> anyhow::Result { let mut content = Vec::new(); let binding = vec![]; diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index c6f2e7e9a5f0..3d36ea906ea7 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -13,7 +13,7 @@ pub struct TokenCounter { const GPT_4O_TOKENIZER_KEY: &str = "Xenova--gpt-4o"; const CLAUDE_TOKENIZER_KEY: &str = "Xenova--claude-tokenizer"; -const GOOGLE_TOKENIZER_KEY: &str = "Xenova--google-tokenizer"; +const GOOGLE_TOKENIZER_KEY: &str = "Xenova--gemini-nano"; const QWEN_TOKENIZER_KEY: &str = "Qwen--Qwen2.5-Coder-32B-Instruct"; impl Default for TokenCounter { @@ -62,7 +62,6 @@ impl TokenCounter { fn model_to_tokenizer_key(model_name: Option<&str>) -> &str { let model_name = model_name.unwrap_or("gpt-4o").to_lowercase(); // Lifei: to remove - println!("Model name: {}", model_name); if model_name.contains("claude") { CLAUDE_TOKENIZER_KEY } else if model_name.contains("qwen") { diff --git a/download_tokenizers.sh b/download_tokenizers.sh index 6658194d4eda..4586539381f9 100755 --- a/download_tokenizers.sh +++ b/download_tokenizers.sh @@ -31,4 +31,5 @@ download_tokenizer() { # Download tokenizers for each model download_tokenizer "Xenova/gpt-4o" download_tokenizer "Xenova/claude-tokenizer" -download_tokenizer "Qwen/Qwen2.5-Coder-32B-Instruct" \ No newline at end of file +download_tokenizer "Qwen/Qwen2.5-Coder-32B-Instruct" +download_tokenizer "Xenova/gemini-nano" From af30e0f77d8516c5d0cb86fcb2ef584324c2204a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 13:41:17 +1100 Subject: [PATCH 03/17] reformat code --- crates/goose-cli/src/profile.rs | 5 +- crates/goose-server/src/configuration.rs | 4 +- crates/goose/src/providers.rs | 2 +- crates/goose/src/providers/configs.rs | 1 - crates/goose/src/providers/factory.rs | 6 +- crates/goose/src/providers/google.rs | 139 ++++++++++++++++------- crates/goose/src/token_counter.rs | 6 +- 7 files changed, 112 insertions(+), 51 deletions(-) diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index 8bc4f521362e..6ab4d4424f87 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -1,6 +1,9 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; -use goose::providers::configs::{AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig}; +use goose::providers::configs::{ + AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, + OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, +}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index a1fb929be130..b2785e8297e6 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -1,5 +1,6 @@ use crate::error::{to_env_var, ConfigError}; use config::{Config, Environment}; +use goose::providers::configs::GoogleProviderConfig; use goose::providers::{ configs::{ DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, @@ -11,7 +12,6 @@ use goose::providers::{ }; use serde::Deserialize; use std::net::SocketAddr; -use goose::providers::configs::GoogleProviderConfig; #[derive(Debug, Default, Deserialize)] pub struct ServerSettings { @@ -131,7 +131,7 @@ impl ProviderSettings { temperature, max_tokens, }), - ProviderSettings::Google{ + ProviderSettings::Google { host, api_key, model, diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 1839a8c54d27..a9eeb108b92a 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -8,6 +8,6 @@ pub mod ollama; pub mod openai; pub mod utils; +mod google; #[cfg(test)] pub mod mock; -mod google; diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 04c00b6d6141..678a037879c6 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -91,7 +91,6 @@ pub struct GoogleProviderConfig { pub max_tokens: Option, } - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaProviderConfig { pub host: String, diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index f34fe22dee54..832d5fcef7ed 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -2,9 +2,9 @@ use super::{ anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig, databricks::DatabricksProvider, ollama::OllamaProvider, openai::OpenAiProvider, }; +use crate::providers::google::GoogleProvider; use anyhow::Result; use strum_macros::EnumIter; -use crate::providers::google::GoogleProvider; #[derive(EnumIter, Debug)] pub enum ProviderType { @@ -25,8 +25,6 @@ pub fn get_provider(config: ProviderConfig) -> Result { Ok(Box::new(AnthropicProvider::new(anthropic_config)?)) } - ProviderConfig::Google(google_config) => { - Ok(Box::new(GoogleProvider::new(google_config)?)) - } + ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)), } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 45dfe84cf78f..b7abcdfe1d85 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -1,14 +1,14 @@ -use std::time::Duration; -use anyhow::anyhow; -use async_trait::async_trait; -use reqwest::{Client, StatusCode}; -use serde_json::{json, Value, Map}; -use mcp_core::{Content, Role, Tool, ToolCall}; use crate::errors::AgentError; use crate::message::{Message, MessageContent}; use crate::providers::base::{Provider, ProviderUsageCollector, Usage}; use crate::providers::configs::GoogleProviderConfig; -use crate::providers::utils::{is_valid_function_name}; +use crate::providers::utils::is_valid_function_name; +use anyhow::anyhow; +use async_trait::async_trait; +use mcp_core::{Content, Role, Tool, ToolCall}; +use reqwest::{Client, StatusCode}; +use serde_json::{json, Map, Value}; +use std::time::Duration; pub struct GoogleProvider { client: Client, @@ -97,10 +97,19 @@ impl Provider for GoogleProvider { ) -> anyhow::Result<(Message, Usage)> { // Lifei: TODO: temperature parameters, tools may be empty, images let mut payload = Map::new(); - payload.insert("system_instruction".to_string(), json!({"parts": [{"text": system}]})); - payload.insert("contents".to_string(), json!(messages_to_google_spec(&messages))); + payload.insert( + "system_instruction".to_string(), + json!({"parts": [{"text": system}]}), + ); + payload.insert( + "contents".to_string(), + json!(messages_to_google_spec(&messages)), + ); if !tools.is_empty() { - payload.insert("tools".to_string(), json!({"functionDeclarations": tools_to_google_spec(&tools)})); + payload.insert( + "tools".to_string(), + json!({"functionDeclarations": tools_to_google_spec(&tools)}), + ); } // Make request @@ -124,8 +133,6 @@ impl Provider for GoogleProvider { Ok((message, usage)) } - - fn total_usage(&self) -> Usage { self.usage_collector.get_usage() } @@ -135,7 +142,11 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { messages .iter() .map(|message| { - let role = if message.role == Role::User { "user" } else { "model" }; + let role = if message.role == Role::User { + "user" + } else { + "model" + }; let mut parts = Vec::new(); for message_content in message.content.iter() { match message_content { @@ -148,8 +159,11 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { Ok(tool_call) => { let mut function_call_part = Map::new(); function_call_part.insert("name".to_string(), json!(tool_call.name)); - if tool_call.arguments.is_object() && !tool_call.arguments.as_object().unwrap().is_empty() { - function_call_part.insert("arguments".to_string(), tool_call.arguments.clone()); + if tool_call.arguments.is_object() + && !tool_call.arguments.as_object().unwrap().is_empty() + { + function_call_part + .insert("arguments".to_string(), tool_call.arguments.clone()); } parts.push(json!({ "functionCall": function_call_part @@ -158,7 +172,7 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { Err(e) => { parts.push(json!({"text":format!("Error: {}", e)})); } - } + }, MessageContent::ToolResponse(response) => { match &response.tool_result { Ok(contents) => { @@ -166,17 +180,16 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { let abridged: Vec<_> = contents .iter() .filter(|content| { - content - .audience() - .is_none_or(|audience| audience.contains(&Role::Assistant)) + content.audience().is_none_or(|audience| { + audience.contains(&Role::Assistant) + }) }) .map(|content| content.unannotated()) .collect(); for content in abridged { match content { - Content::Image(image) => { - } + Content::Image(image) => {} _ => { parts.push(json!({ "functionResponse": { @@ -202,7 +215,7 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { .collect() } -fn tools_to_google_spec(tools: &[Tool]) -> Vec { +fn tools_to_google_spec(tools: &[Tool]) -> Vec { tools .iter() .map(|tool| { @@ -210,10 +223,33 @@ fn tools_to_google_spec(tools: &[Tool]) -> Vec { parameters.insert("name".to_string(), json!(tool.name)); parameters.insert("description".to_string(), json!(tool.description)); let tool_input_schema = tool.input_schema.as_object().unwrap(); - let tool_input_schema_properties = tool_input_schema.get("properties").unwrap_or(&json!({})).as_object().unwrap().clone(); + let tool_input_schema_properties = tool_input_schema + .get("properties") + .unwrap_or(&json!({})) + .as_object() + .unwrap() + .clone(); if !tool_input_schema_properties.is_empty() { - let accepted_tool_schema_attributes = vec!["type".to_string(), "format".to_string(), "description".to_string(), "nullable".to_string(), "enum".to_string(), "maxItems".to_string(), "minItems".to_string(), "properties".to_string(), "required".to_string(), "items".to_string()]; - parameters.insert("parameters".to_string(), json!(process_map(tool_input_schema, &accepted_tool_schema_attributes, None))); + let accepted_tool_schema_attributes = vec![ + "type".to_string(), + "format".to_string(), + "description".to_string(), + "nullable".to_string(), + "enum".to_string(), + "maxItems".to_string(), + "minItems".to_string(), + "properties".to_string(), + "required".to_string(), + "items".to_string(), + ]; + parameters.insert( + "parameters".to_string(), + json!(process_map( + tool_input_schema, + &accepted_tool_schema_attributes, + None + )), + ); } json!(parameters) }) @@ -234,13 +270,14 @@ fn process_map( } // Process nested maps recursively let filtered_value = match value { - Value::Object(nested_map) => { - process_map( - &nested_map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(), - accepted_keys, - Some(key), - ) - } + Value::Object(nested_map) => process_map( + &nested_map + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + accepted_keys, + Some(key), + ), _ => value.clone(), }; @@ -257,22 +294,38 @@ fn process_map( fn google_response_to_message(response: Value) -> anyhow::Result { let mut content = Vec::new(); let binding = vec![]; - let candidates: &Vec = response.get("candidates").and_then(|v| v.as_array()).unwrap_or(&binding); + let candidates: &Vec = response + .get("candidates") + .and_then(|v| v.as_array()) + .unwrap_or(&binding); let candidate = candidates.get(0); let role = Role::Assistant; let created = chrono::Utc::now().timestamp(); if candidate.is_none() { - return Ok(Message { role, created, content}); + return Ok(Message { + role, + created, + content, + }); } let candidate = candidate.unwrap(); - let parts = candidate.get("content") - .and_then(|content| content.get("parts")).and_then(|parts| parts.as_array()).unwrap_or(&binding); + let parts = candidate + .get("content") + .and_then(|content| content.get("parts")) + .and_then(|parts| parts.as_array()) + .unwrap_or(&binding); for part in parts { if let Some(text) = part.get("text").and_then(|v| v.as_str()) { content.push(MessageContent::text(text.to_string())); } else if let Some(function_call) = part.get("functionCall") { - let id = function_call["name"].as_str().unwrap_or_default().to_string(); - let name = function_call["name"].as_str().unwrap_or_default().to_string(); + let id = function_call["name"] + .as_str() + .unwrap_or_default() + .to_string(); + let name = function_call["name"] + .as_str() + .unwrap_or_default() + .to_string(); if !is_valid_function_name(&name) { let error = AgentError::ToolNotFound(format!( "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", @@ -284,11 +337,15 @@ fn google_response_to_message(response: Value) -> anyhow::Result { if parameters.is_some() { content.push(MessageContent::tool_request( id, - Ok(ToolCall::new(&name, parameters.unwrap().clone())))); + Ok(ToolCall::new(&name, parameters.unwrap().clone())), + )); } } - } } - Ok(Message { role, created, content}) + Ok(Message { + role, + created, + content, + }) } diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 3d36ea906ea7..b47f95eadc7c 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -49,7 +49,11 @@ impl TokenCounter { tokenizers: HashMap::new(), }; // Add default tokenizers - for tokenizer_key in [GPT_4O_TOKENIZER_KEY, CLAUDE_TOKENIZER_KEY, GOOGLE_TOKENIZER_KEY] { + for tokenizer_key in [ + GPT_4O_TOKENIZER_KEY, + CLAUDE_TOKENIZER_KEY, + GOOGLE_TOKENIZER_KEY, + ] { counter.load_tokenizer(tokenizer_key); } counter From f3f3ac9ac3b751b5484ae00d8d05f814d2460978 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 16:08:39 +1100 Subject: [PATCH 04/17] fixed the double escape string in the response --- crates/goose/src/providers/google.rs | 45 ++++++++++++++++++++-------- crates/goose/src/token_counter.rs | 1 - 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index b7abcdfe1d85..2158b0f7ec4c 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -115,17 +115,8 @@ impl Provider for GoogleProvider { // Make request let response = self.post(Value::Object(payload)).await?; - // Lifei: TODO handle api errors https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python - // // Raise specific error if context length is exceeded - // if let Some(error) = response.get("error") { - // if let Some(err) = check_openai_context_length_error(error) { - // return Err(err.into()); - // } - // return Err(anyhow!("OpenAI API error: {}", error)); - // } - // Parse response - let message = google_response_to_message(response.clone())?; + let message = google_response_to_message(unescape_json_values(&response))?; let usage = Usage::new(Some(100), Some(100), Some(100)); // let usage = Self::get_usage(&response)?; // self.usage_collector.add_usage(usage.clone()); @@ -163,7 +154,7 @@ fn messages_to_google_spec(messages: &[Message]) -> Vec { && !tool_call.arguments.as_object().unwrap().is_empty() { function_call_part - .insert("arguments".to_string(), tool_call.arguments.clone()); + .insert("args".to_string(), tool_call.arguments.clone()); } parts.push(json!({ "functionCall": function_call_part @@ -333,7 +324,7 @@ fn google_response_to_message(response: Value) -> anyhow::Result { )); content.push(MessageContent::tool_request(id, Err(error))); } else { - let parameters = function_call.get("arguments"); + let parameters = function_call.get("args"); if parameters.is_some() { content.push(MessageContent::tool_request( id, @@ -349,3 +340,33 @@ fn google_response_to_message(response: Value) -> anyhow::Result { content, }) } + +fn unescape_json_values(value: &Value) -> Value { + match value { + Value::Object(map) => { + let new_map: Map = map + .iter() + .map(|(k, v)| (k.clone(), unescape_json_values(v))) // Process each value + .collect(); + Value::Object(new_map) + } + Value::Array(arr) => { + let new_array: Vec = arr.iter() + .map(|v| unescape_json_values(v)) + .collect(); + Value::Array(new_array) + } + Value::String(s) => { + let unescaped = s.replace("\\\\n", "\n") + .replace("\\\\t", "\t") + .replace("\\\\r", "\r") + .replace("\\\\\"", "\"") + .replace("\\n", "\n") + .replace("\\t", "\t") + .replace("\\r", "\r") + .replace("\\\"", "\""); + Value::String(unescaped) + } + _ => value.clone(), + } +} \ No newline at end of file diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index b47f95eadc7c..ef279465c63b 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -65,7 +65,6 @@ impl TokenCounter { fn model_to_tokenizer_key(model_name: Option<&str>) -> &str { let model_name = model_name.unwrap_or("gpt-4o").to_lowercase(); - // Lifei: to remove if model_name.contains("claude") { CLAUDE_TOKENIZER_KEY } else if model_name.contains("qwen") { From c205d244bc9c2214b7a4c641de4158a6f5a2efc8 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 16:10:15 +1100 Subject: [PATCH 05/17] format --- crates/goose/src/providers/google.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 2158b0f7ec4c..77f27b53adbf 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -117,6 +117,7 @@ impl Provider for GoogleProvider { // Parse response let message = google_response_to_message(unescape_json_values(&response))?; + // Lifei: TODO Usage let usage = Usage::new(Some(100), Some(100), Some(100)); // let usage = Self::get_usage(&response)?; // self.usage_collector.add_usage(usage.clone()); @@ -351,13 +352,12 @@ fn unescape_json_values(value: &Value) -> Value { Value::Object(new_map) } Value::Array(arr) => { - let new_array: Vec = arr.iter() - .map(|v| unescape_json_values(v)) - .collect(); + let new_array: Vec = arr.iter().map(|v| unescape_json_values(v)).collect(); Value::Array(new_array) } Value::String(s) => { - let unescaped = s.replace("\\\\n", "\n") + let unescaped = s + .replace("\\\\n", "\n") .replace("\\\\t", "\t") .replace("\\\\r", "\r") .replace("\\\\\"", "\"") @@ -369,4 +369,4 @@ fn unescape_json_values(value: &Value) -> Value { } _ => value.clone(), } -} \ No newline at end of file +} From 09dfe71c8d4cd9324325ddd8362e1dba841fc572 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 16:28:39 +1100 Subject: [PATCH 06/17] fixed merege conflicts --- crates/goose-cli/src/profile.rs | 15 +++++++++++---- crates/goose-server/src/configuration.rs | 4 +--- crates/goose-server/src/state.rs | 2 -- crates/goose/src/providers/configs.rs | 10 +++++++--- crates/goose/src/providers/google.rs | 19 +++++++++---------- 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index d318bb9827d3..6f1b1a7bf038 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -1,9 +1,6 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; -use goose::providers::configs::{ - AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, ModelConfig, - OpenAiProviderConfig, ProviderConfig, -}; +use goose::providers::configs::{AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, ModelConfig, OpenAiProviderConfig, ProviderConfig, GoogleProviderConfig}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; @@ -125,6 +122,16 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon model: model_config, }) } + "google" => { + let api_key = get_keyring_secret("GOOGLE_API_KEY", KeyRetrievalStrategy::Both) + .expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); + + ProviderConfig::Google(GoogleProviderConfig { + host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint + api_key, + model: model_config, + }) + } _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index 19aaace9ed42..682f58e975ea 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -164,9 +164,7 @@ impl ProviderSettings { } => ProviderConfig::Google(GoogleProviderConfig { host, api_key, - model, - temperature, - max_tokens, + model: ModelConfig::new(model) }), } } diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 909de4c86e03..446c538dcee4 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -69,8 +69,6 @@ impl Clone for AppState { host: config.host.clone(), api_key: config.api_key.clone(), model: config.model.clone(), - temperature: config.temperature, - max_tokens: config.max_tokens, }) } }, diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index bd594442e2f1..67c49282dc5f 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -213,9 +213,13 @@ impl ProviderModelConfig for OpenAiProviderConfig { pub struct GoogleProviderConfig { pub host: String, pub api_key: String, - pub model: String, - pub temperature: Option, - pub max_tokens: Option, + pub model: ModelConfig, +} + +impl ProviderModelConfig for GoogleProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 77f27b53adbf..3316d31f6dd7 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -1,7 +1,7 @@ use crate::errors::AgentError; use crate::message::{Message, MessageContent}; -use crate::providers::base::{Provider, ProviderUsageCollector, Usage}; -use crate::providers::configs::GoogleProviderConfig; +use crate::providers::base::{Provider, ProviderUsage, Usage}; +use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; use crate::providers::utils::is_valid_function_name; use anyhow::anyhow; use async_trait::async_trait; @@ -9,11 +9,11 @@ use mcp_core::{Content, Role, Tool, ToolCall}; use reqwest::{Client, StatusCode}; use serde_json::{json, Map, Value}; use std::time::Duration; +use rust_decimal_macros::dec; pub struct GoogleProvider { client: Client, config: GoogleProviderConfig, - usage_collector: ProviderUsageCollector, } impl GoogleProvider { @@ -25,7 +25,6 @@ impl GoogleProvider { Ok(Self { client, config, - usage_collector: ProviderUsageCollector::new(), }) } @@ -60,7 +59,7 @@ impl GoogleProvider { let url = format!( "{}/v1beta/models/{}:generateContent?key={}", self.config.host.trim_end_matches('/'), - self.config.model, + self.config.model.model_name, self.config.api_key ); @@ -94,7 +93,7 @@ impl Provider for GoogleProvider { system: &str, messages: &[Message], tools: &[Tool], - ) -> anyhow::Result<(Message, Usage)> { + ) -> anyhow::Result<(Message, ProviderUsage)> { // Lifei: TODO: temperature parameters, tools may be empty, images let mut payload = Map::new(); payload.insert( @@ -121,12 +120,12 @@ impl Provider for GoogleProvider { let usage = Usage::new(Some(100), Some(100), Some(100)); // let usage = Self::get_usage(&response)?; // self.usage_collector.add_usage(usage.clone()); - - Ok((message, usage)) + let provider_usage = ProviderUsage::new("gpt-4o".to_string(), usage, Some(dec!(0.0))); + Ok((message, provider_usage)) } - fn total_usage(&self) -> Usage { - self.usage_collector.get_usage() + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() } } From 2d9d59a9a42b2e76e6375e2f47020e2219e3e211 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 16:29:28 +1100 Subject: [PATCH 07/17] format --- crates/goose-cli/src/profile.rs | 5 ++++- crates/goose-server/src/configuration.rs | 2 +- crates/goose/src/providers/google.rs | 7 ++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index 6f1b1a7bf038..6e03f6b387cc 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -1,6 +1,9 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; -use goose::providers::configs::{AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, ModelConfig, OpenAiProviderConfig, ProviderConfig, GoogleProviderConfig}; +use goose::providers::configs::{ + AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, + ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, +}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index 682f58e975ea..d3a8a0a61cdb 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -164,7 +164,7 @@ impl ProviderSettings { } => ProviderConfig::Google(GoogleProviderConfig { host, api_key, - model: ModelConfig::new(model) + model: ModelConfig::new(model), }), } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 3316d31f6dd7..f03b02b7f521 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -7,9 +7,9 @@ use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; use reqwest::{Client, StatusCode}; +use rust_decimal_macros::dec; use serde_json::{json, Map, Value}; use std::time::Duration; -use rust_decimal_macros::dec; pub struct GoogleProvider { client: Client, @@ -22,10 +22,7 @@ impl GoogleProvider { .timeout(Duration::from_secs(600)) // 10 minutes timeout .build()?; - Ok(Self { - client, - config, - }) + Ok(Self { client, config }) } fn get_usage(data: &Value) -> anyhow::Result { From 882a548974398a0017a3a6e367cee2065bf2ec35 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 16:38:43 +1100 Subject: [PATCH 08/17] passed the temperature and maxoutput tokens --- crates/goose/src/providers/google.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index f03b02b7f521..c2c118b089ea 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -85,13 +85,17 @@ impl GoogleProvider { #[async_trait] impl Provider for GoogleProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + async fn complete( &self, system: &str, messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage)> { - // Lifei: TODO: temperature parameters, tools may be empty, images + // Lifei: TODO: temperature parameters, images let mut payload = Map::new(); payload.insert( "system_instruction".to_string(), @@ -107,6 +111,18 @@ impl Provider for GoogleProvider { json!({"functionDeclarations": tools_to_google_spec(&tools)}), ); } + let mut generation_config = Map::new(); + if let Some(temp) = self.config.model.temperature { + generation_config + .insert("temperature".to_string(), json!(temp)); + } + if let Some(tokens) = self.config.model.max_tokens { + generation_config + .insert("maxOutputTokens".to_string(), json!(tokens)); + } + if !generation_config.is_empty() { + payload.insert("generationConfig".to_string(), json!(generation_config)); + } // Make request let response = self.post(Value::Object(payload)).await?; @@ -120,10 +136,6 @@ impl Provider for GoogleProvider { let provider_usage = ProviderUsage::new("gpt-4o".to_string(), usage, Some(dec!(0.0))); Ok((message, provider_usage)) } - - fn get_model_config(&self) -> &ModelConfig { - self.config.model_config() - } } fn messages_to_google_spec(messages: &[Message]) -> Vec { From 054356cc9a58e28d42cc48d64ca544f2959fb58a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 17:09:12 +1100 Subject: [PATCH 09/17] added the provider usage --- crates/goose/src/providers/google.rs | 60 +++++++++++++--------------- 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index c2c118b089ea..e3358d348391 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -7,7 +7,6 @@ use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; use reqwest::{Client, StatusCode}; -use rust_decimal_macros::dec; use serde_json::{json, Map, Value}; use std::time::Duration; @@ -25,31 +24,25 @@ impl GoogleProvider { Ok(Self { client, config }) } - fn get_usage(data: &Value) -> anyhow::Result { - let usage = data - .get("usage") - .ok_or_else(|| anyhow!("No usage data in response"))?; - - let input_tokens = usage - .get("prompt_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let output_tokens = usage - .get("completion_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32); - - let total_tokens = usage - .get("total_tokens") - .and_then(|v| v.as_i64()) - .map(|v| v as i32) - .or_else(|| match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - _ => None, - }); - - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + fn get_usage(&self, data: &Value) -> anyhow::Result { + if let Some(usage_meta_data) = data.get("usageMetadata") { + let input_tokens = usage_meta_data + .get("promptTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let output_tokens = usage_meta_data + .get("candidatesTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let total_tokens = usage_meta_data + .get("totalTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + } else { + // If no usage data, return None for all values + Ok(Usage::new(None, None, None)) + } } async fn post(&self, payload: Value) -> anyhow::Result { @@ -95,7 +88,7 @@ impl Provider for GoogleProvider { messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage)> { - // Lifei: TODO: temperature parameters, images + // Lifei: TODO: images let mut payload = Map::new(); payload.insert( "system_instruction".to_string(), @@ -126,14 +119,15 @@ impl Provider for GoogleProvider { // Make request let response = self.post(Value::Object(payload)).await?; - // Parse response let message = google_response_to_message(unescape_json_values(&response))?; - // Lifei: TODO Usage - let usage = Usage::new(Some(100), Some(100), Some(100)); - // let usage = Self::get_usage(&response)?; - // self.usage_collector.add_usage(usage.clone()); - let provider_usage = ProviderUsage::new("gpt-4o".to_string(), usage, Some(dec!(0.0))); + let usage = self.get_usage(&response)?; + let model = match response.get("modelVersion") { + Some(model_version) => model_version.as_str().unwrap_or_default().to_string(), + None => self.config.model.model_name.clone(), + }; + let provider_usage = ProviderUsage::new(model, usage, None); + println!("====== Google provider: {:?}", provider_usage); Ok((message, provider_usage)) } } From e96b76d33dc923f3e6c94f276180ba48044e4644 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 17:16:03 +1100 Subject: [PATCH 10/17] moved the function in the class --- crates/goose/src/providers/google.rs | 473 +++++++++++++-------------- crates/goose/src/providers/utils.rs | 31 +- 2 files changed, 252 insertions(+), 252 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index e3358d348391..f5f6c1b01253 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -2,7 +2,7 @@ use crate::errors::AgentError; use crate::message::{Message, MessageContent}; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; -use crate::providers::utils::is_valid_function_name; +use crate::providers::utils::{is_valid_function_name, unescape_json_values}; use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; @@ -74,6 +74,221 @@ impl GoogleProvider { )), } } + + fn messages_to_google_spec(&self, messages: &[Message]) -> Vec { + messages + .iter() + .map(|message| { + let role = if message.role == Role::User { + "user" + } else { + "model" + }; + let mut parts = Vec::new(); + for message_content in message.content.iter() { + match message_content { + MessageContent::Text(text) => { + if !text.text.is_empty() { + parts.push(json!({"text": text.text})); + } + } + MessageContent::ToolRequest(request) => match &request.tool_call { + Ok(tool_call) => { + let mut function_call_part = Map::new(); + function_call_part + .insert("name".to_string(), json!(tool_call.name)); + if tool_call.arguments.is_object() + && !tool_call.arguments.as_object().unwrap().is_empty() + { + function_call_part + .insert("args".to_string(), tool_call.arguments.clone()); + } + parts.push(json!({ + "functionCall": function_call_part + })); + } + Err(e) => { + parts.push(json!({"text":format!("Error: {}", e)})); + } + }, + MessageContent::ToolResponse(response) => { + match &response.tool_result { + Ok(contents) => { + // Send only contents with no audience or with Assistant in the audience + let abridged: Vec<_> = contents + .iter() + .filter(|content| { + content.audience().is_none_or(|audience| { + audience.contains(&Role::Assistant) + }) + }) + .map(|content| content.unannotated()) + .collect(); + + for content in abridged { + match content { + Content::Image(image) => {} + _ => { + parts.push(json!({ + "functionResponse": { + "name": response.id, + "response": {"content": content}, + }} + )); + } + } + } + } + Err(e) => { + parts.push(json!({"text":format!("Error: {}", e)})); + } + } + } + + _ => {} + } + } + json!({"role": role, "parts": parts}) + }) + .collect() + } + + fn tools_to_google_spec(&self, tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| { + let mut parameters = Map::new(); + parameters.insert("name".to_string(), json!(tool.name)); + parameters.insert("description".to_string(), json!(tool.description)); + let tool_input_schema = tool.input_schema.as_object().unwrap(); + let tool_input_schema_properties = tool_input_schema + .get("properties") + .unwrap_or(&json!({})) + .as_object() + .unwrap() + .clone(); + if !tool_input_schema_properties.is_empty() { + let accepted_tool_schema_attributes = vec![ + "type".to_string(), + "format".to_string(), + "description".to_string(), + "nullable".to_string(), + "enum".to_string(), + "maxItems".to_string(), + "minItems".to_string(), + "properties".to_string(), + "required".to_string(), + "items".to_string(), + ]; + parameters.insert( + "parameters".to_string(), + json!(self.process_map( + tool_input_schema, + &accepted_tool_schema_attributes, + None + )), + ); + } + json!(parameters) + }) + .collect() + } + + fn process_map( + &self, + map: &Map, + accepted_keys: &[String], + parent_key: Option<&str>, // Track the parent key + ) -> Value { + let mut filtered_map: Map = map + .iter() + .filter_map(|(key, value)| { + let should_remove = + !accepted_keys.contains(key) && parent_key != Some("properties"); + if should_remove { + return None; + } + // Process nested maps recursively + let filtered_value = match value { + Value::Object(nested_map) => self.process_map( + &nested_map + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + accepted_keys, + Some(key), + ), + _ => value.clone(), + }; + + Some((key.clone(), filtered_value)) + }) + .collect(); + if parent_key != Some("properties") && !filtered_map.contains_key("type") { + filtered_map.insert("type".to_string(), Value::String("string".to_string())); + } + + Value::Object(filtered_map) + } + + fn google_response_to_message(&self, response: Value) -> anyhow::Result { + let mut content = Vec::new(); + let binding = vec![]; + let candidates: &Vec = response + .get("candidates") + .and_then(|v| v.as_array()) + .unwrap_or(&binding); + let candidate = candidates.get(0); + let role = Role::Assistant; + let created = chrono::Utc::now().timestamp(); + if candidate.is_none() { + return Ok(Message { + role, + created, + content, + }); + } + let candidate = candidate.unwrap(); + let parts = candidate + .get("content") + .and_then(|content| content.get("parts")) + .and_then(|parts| parts.as_array()) + .unwrap_or(&binding); + for part in parts { + if let Some(text) = part.get("text").and_then(|v| v.as_str()) { + content.push(MessageContent::text(text.to_string())); + } else if let Some(function_call) = part.get("functionCall") { + let id = function_call["name"] + .as_str() + .unwrap_or_default() + .to_string(); + let name = function_call["name"] + .as_str() + .unwrap_or_default() + .to_string(); + if !is_valid_function_name(&name) { + let error = AgentError::ToolNotFound(format!( + "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", + name + )); + content.push(MessageContent::tool_request(id, Err(error))); + } else { + let parameters = function_call.get("args"); + if parameters.is_some() { + content.push(MessageContent::tool_request( + id, + Ok(ToolCall::new(&name, parameters.unwrap().clone())), + )); + } + } + } + } + Ok(Message { + role, + created, + content, + }) + } } #[async_trait] @@ -96,22 +311,20 @@ impl Provider for GoogleProvider { ); payload.insert( "contents".to_string(), - json!(messages_to_google_spec(&messages)), + json!(self.messages_to_google_spec(&messages)), ); if !tools.is_empty() { payload.insert( "tools".to_string(), - json!({"functionDeclarations": tools_to_google_spec(&tools)}), + json!({"functionDeclarations": self.tools_to_google_spec(&tools)}), ); } let mut generation_config = Map::new(); if let Some(temp) = self.config.model.temperature { - generation_config - .insert("temperature".to_string(), json!(temp)); + generation_config.insert("temperature".to_string(), json!(temp)); } if let Some(tokens) = self.config.model.max_tokens { - generation_config - .insert("maxOutputTokens".to_string(), json!(tokens)); + generation_config.insert("maxOutputTokens".to_string(), json!(tokens)); } if !generation_config.is_empty() { payload.insert("generationConfig".to_string(), json!(generation_config)); @@ -120,255 +333,13 @@ impl Provider for GoogleProvider { // Make request let response = self.post(Value::Object(payload)).await?; // Parse response - let message = google_response_to_message(unescape_json_values(&response))?; + let message = self.google_response_to_message(unescape_json_values(&response))?; let usage = self.get_usage(&response)?; - let model = match response.get("modelVersion") { + let model = match response.get("modelVersion") { Some(model_version) => model_version.as_str().unwrap_or_default().to_string(), None => self.config.model.model_name.clone(), }; let provider_usage = ProviderUsage::new(model, usage, None); - println!("====== Google provider: {:?}", provider_usage); Ok((message, provider_usage)) } } - -fn messages_to_google_spec(messages: &[Message]) -> Vec { - messages - .iter() - .map(|message| { - let role = if message.role == Role::User { - "user" - } else { - "model" - }; - let mut parts = Vec::new(); - for message_content in message.content.iter() { - match message_content { - MessageContent::Text(text) => { - if !text.text.is_empty() { - parts.push(json!({"text": text.text})); - } - } - MessageContent::ToolRequest(request) => match &request.tool_call { - Ok(tool_call) => { - let mut function_call_part = Map::new(); - function_call_part.insert("name".to_string(), json!(tool_call.name)); - if tool_call.arguments.is_object() - && !tool_call.arguments.as_object().unwrap().is_empty() - { - function_call_part - .insert("args".to_string(), tool_call.arguments.clone()); - } - parts.push(json!({ - "functionCall": function_call_part - })); - } - Err(e) => { - parts.push(json!({"text":format!("Error: {}", e)})); - } - }, - MessageContent::ToolResponse(response) => { - match &response.tool_result { - Ok(contents) => { - // Send only contents with no audience or with Assistant in the audience - let abridged: Vec<_> = contents - .iter() - .filter(|content| { - content.audience().is_none_or(|audience| { - audience.contains(&Role::Assistant) - }) - }) - .map(|content| content.unannotated()) - .collect(); - - for content in abridged { - match content { - Content::Image(image) => {} - _ => { - parts.push(json!({ - "functionResponse": { - "name": response.id, - "response": {"content": content}, - }} - )); - } - } - } - } - Err(e) => { - parts.push(json!({"text":format!("Error: {}", e)})); - } - } - } - - _ => {} - } - } - json!({"role": role, "parts": parts}) - }) - .collect() -} - -fn tools_to_google_spec(tools: &[Tool]) -> Vec { - tools - .iter() - .map(|tool| { - let mut parameters = Map::new(); - parameters.insert("name".to_string(), json!(tool.name)); - parameters.insert("description".to_string(), json!(tool.description)); - let tool_input_schema = tool.input_schema.as_object().unwrap(); - let tool_input_schema_properties = tool_input_schema - .get("properties") - .unwrap_or(&json!({})) - .as_object() - .unwrap() - .clone(); - if !tool_input_schema_properties.is_empty() { - let accepted_tool_schema_attributes = vec![ - "type".to_string(), - "format".to_string(), - "description".to_string(), - "nullable".to_string(), - "enum".to_string(), - "maxItems".to_string(), - "minItems".to_string(), - "properties".to_string(), - "required".to_string(), - "items".to_string(), - ]; - parameters.insert( - "parameters".to_string(), - json!(process_map( - tool_input_schema, - &accepted_tool_schema_attributes, - None - )), - ); - } - json!(parameters) - }) - .collect() -} - -fn process_map( - map: &Map, - accepted_keys: &[String], - parent_key: Option<&str>, // Track the parent key -) -> Value { - let mut filtered_map: Map = map - .iter() - .filter_map(|(key, value)| { - let should_remove = !accepted_keys.contains(key) && parent_key != Some("properties"); - if should_remove { - return None; - } - // Process nested maps recursively - let filtered_value = match value { - Value::Object(nested_map) => process_map( - &nested_map - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect(), - accepted_keys, - Some(key), - ), - _ => value.clone(), - }; - - Some((key.clone(), filtered_value)) - }) - .collect(); - if parent_key != Some("properties") && !filtered_map.contains_key("type") { - filtered_map.insert("type".to_string(), Value::String("string".to_string())); - } - - Value::Object(filtered_map) -} - -fn google_response_to_message(response: Value) -> anyhow::Result { - let mut content = Vec::new(); - let binding = vec![]; - let candidates: &Vec = response - .get("candidates") - .and_then(|v| v.as_array()) - .unwrap_or(&binding); - let candidate = candidates.get(0); - let role = Role::Assistant; - let created = chrono::Utc::now().timestamp(); - if candidate.is_none() { - return Ok(Message { - role, - created, - content, - }); - } - let candidate = candidate.unwrap(); - let parts = candidate - .get("content") - .and_then(|content| content.get("parts")) - .and_then(|parts| parts.as_array()) - .unwrap_or(&binding); - for part in parts { - if let Some(text) = part.get("text").and_then(|v| v.as_str()) { - content.push(MessageContent::text(text.to_string())); - } else if let Some(function_call) = part.get("functionCall") { - let id = function_call["name"] - .as_str() - .unwrap_or_default() - .to_string(); - let name = function_call["name"] - .as_str() - .unwrap_or_default() - .to_string(); - if !is_valid_function_name(&name) { - let error = AgentError::ToolNotFound(format!( - "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", - name - )); - content.push(MessageContent::tool_request(id, Err(error))); - } else { - let parameters = function_call.get("args"); - if parameters.is_some() { - content.push(MessageContent::tool_request( - id, - Ok(ToolCall::new(&name, parameters.unwrap().clone())), - )); - } - } - } - } - Ok(Message { - role, - created, - content, - }) -} - -fn unescape_json_values(value: &Value) -> Value { - match value { - Value::Object(map) => { - let new_map: Map = map - .iter() - .map(|(k, v)| (k.clone(), unescape_json_values(v))) // Process each value - .collect(); - Value::Object(new_map) - } - Value::Array(arr) => { - let new_array: Vec = arr.iter().map(|v| unescape_json_values(v)).collect(); - Value::Array(new_array) - } - Value::String(s) => { - let unescaped = s - .replace("\\\\n", "\n") - .replace("\\\\t", "\t") - .replace("\\\\r", "\r") - .replace("\\\\\"", "\"") - .replace("\\n", "\n") - .replace("\\t", "\t") - .replace("\\r", "\r") - .replace("\\\"", "\""); - Value::String(unescaped) - } - _ => value.clone(), - } -} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 08dc44f506c9..24fbf22ffc46 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use regex::Regex; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use crate::errors::AgentError; use crate::message::{Message, MessageContent}; @@ -287,6 +287,35 @@ pub fn get_model(data: &Value) -> String { } } +pub fn unescape_json_values(value: &Value) -> Value { + match value { + Value::Object(map) => { + let new_map: Map = map + .iter() + .map(|(k, v)| (k.clone(), unescape_json_values(v))) // Process each value + .collect(); + Value::Object(new_map) + } + Value::Array(arr) => { + let new_array: Vec = arr.iter().map(|v| unescape_json_values(v)).collect(); + Value::Array(new_array) + } + Value::String(s) => { + let unescaped = s + .replace("\\\\n", "\n") + .replace("\\\\t", "\t") + .replace("\\\\r", "\r") + .replace("\\\\\"", "\"") + .replace("\\n", "\n") + .replace("\\t", "\t") + .replace("\\r", "\r") + .replace("\\\"", "\""); + Value::String(unescaped) + } + _ => value.clone(), + } +} + #[cfg(test)] mod tests { use super::*; From 6d053c35d52e9f04e24d4f52c309f87002c69241 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 17:52:18 +1100 Subject: [PATCH 11/17] added the image data from function response in message --- crates/goose/src/providers/google.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index f5f6c1b01253..6202a0fb0b77 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -127,7 +127,14 @@ impl GoogleProvider { for content in abridged { match content { - Content::Image(image) => {} + Content::Image(image) => { + parts.push(json!({ + "inline_data": { + "mime_type": image.mime_type, + "data": image.data, + } + })); + } _ => { parts.push(json!({ "functionResponse": { @@ -303,7 +310,6 @@ impl Provider for GoogleProvider { messages: &[Message], tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage)> { - // Lifei: TODO: images let mut payload = Map::new(); payload.insert( "system_instruction".to_string(), From 5797af8bc8e8327e294719bccaf720d6314f7c30 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 18:28:04 +1100 Subject: [PATCH 12/17] sanitise function name --- crates/goose/src/providers/google.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 6202a0fb0b77..f0ce5b2ea46e 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -2,7 +2,9 @@ use crate::errors::AgentError; use crate::message::{Message, MessageContent}; use crate::providers::base::{Provider, ProviderUsage, Usage}; use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; -use crate::providers::utils::{is_valid_function_name, unescape_json_values}; +use crate::providers::utils::{ + is_valid_function_name, sanitize_function_name, unescape_json_values, +}; use anyhow::anyhow; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; @@ -95,8 +97,10 @@ impl GoogleProvider { MessageContent::ToolRequest(request) => match &request.tool_call { Ok(tool_call) => { let mut function_call_part = Map::new(); - function_call_part - .insert("name".to_string(), json!(tool_call.name)); + function_call_part.insert( + "name".to_string(), + json!(sanitize_function_name(&tool_call.name)), + ); if tool_call.arguments.is_object() && !tool_call.arguments.as_object().unwrap().is_empty() { From 48eb28565be642a39c6724d83d1b0246e2be9a07 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 20:15:33 +1100 Subject: [PATCH 13/17] fixed compilation error --- crates/goose-cli/src/commands/configure.rs | 2 +- crates/goose-server/src/configuration.rs | 14 ++++++++++++-- crates/goose/src/providers.rs | 2 +- crates/goose/src/providers/google.rs | 3 +++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 100339795ee8..1308513ccf40 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -158,7 +158,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str { "databricks" => "claude-3-5-sonnet-2", "ollama" => OLLAMA_MODEL, "anthropic" => "claude-3-5-sonnet-2", - "google" => "gemini-1.5-flash-8b", + "google" => "gemini-1.5-flash", _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index d3a8a0a61cdb..de47633013a1 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -7,7 +7,7 @@ use goose::providers::{ OpenAiProviderConfig, ProviderConfig, }, factory::ProviderType, - ollama, + google, ollama, utils::ImageFormat, }; use serde::Deserialize; @@ -164,7 +164,9 @@ impl ProviderSettings { } => ProviderConfig::Google(GoogleProviderConfig { host, api_key, - model: ModelConfig::new(model), + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens), }), } } @@ -257,6 +259,14 @@ fn default_ollama_model() -> String { ollama::OLLAMA_MODEL.to_string() } +fn default_google_host() -> String { + google::GOOGLE_API_HOST.to_string() +} + +fn default_google_model() -> String { + google::GOOGLE_DEFAULT_MODEL.to_string() +} + fn default_image_format() -> ImageFormat { ImageFormat::Anthropic } diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index ffd72da36814..f2d7758aec67 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -9,6 +9,6 @@ pub mod ollama; pub mod openai; pub mod utils; -mod google; +pub mod google; #[cfg(test)] pub mod mock; diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index f0ce5b2ea46e..0374544207ab 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -12,6 +12,9 @@ use reqwest::{Client, StatusCode}; use serde_json::{json, Map, Value}; use std::time::Duration; +pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com"; +pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-1.5-flash"; + pub struct GoogleProvider { client: Client, config: GoogleProviderConfig, From 1b1b8271c249330578c89fc47f54afa7b5d4f65b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 20:52:17 +1100 Subject: [PATCH 14/17] added downloading tokenizers.json --- download_tokenizer_files.py | 1 + 1 file changed, 1 insertion(+) diff --git a/download_tokenizer_files.py b/download_tokenizer_files.py index 19437d10caae..f25b4a2a6980 100644 --- a/download_tokenizer_files.py +++ b/download_tokenizer_files.py @@ -16,6 +16,7 @@ "Xenova/gpt-4o", "Xenova/claude-tokenizer", "Qwen/Qwen2.5-Coder-32B-Instruct", + "Xenova/gemini-nano", ]: download_dir = BASE_DIR / repo_id.replace("/", "--") _path = hf_hub_download(repo_id, filename="tokenizer.json", local_dir=download_dir) From bf1ba7cb2136285287541edceb0b5d3eb4e1065b Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 17 Dec 2024 22:31:35 +1100 Subject: [PATCH 15/17] added tests --- crates/goose/src/providers/google.rs | 262 +++++++++++++++++++++++++++ crates/goose/src/providers/utils.rs | 50 +++++ 2 files changed, 312 insertions(+) diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 0374544207ab..e5ab052de328 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -356,3 +356,265 @@ impl Provider for GoogleProvider { Ok((message, provider_usage)) } } + +#[cfg(test)] // Only compiles this module when running tests +mod tests { + use super::*; + use crate::errors::AgentResult; + fn set_up_provider() -> GoogleProvider { + let provider_config = GoogleProviderConfig { + host: "dummy_host".to_string(), + api_key: "dummy_key".to_string(), + model: ModelConfig::new("dummy_model".to_string()), + }; + GoogleProvider::new(provider_config).unwrap() + } + + fn set_up_text_message(text: &str, role: Role) -> Message { + Message { + role, + created: 0, + content: vec![MessageContent::text(text.to_string())], + } + } + + fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message { + Message { + role: Role::User, + created: 0, + content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], + } + } + + fn set_up_tool_response_message(id: &str, tool_response: Vec) -> Message { + Message { + role: Role::Assistant, + created: 0, + content: vec![MessageContent::tool_response( + id.to_string(), + Ok(tool_response), + )], + } + } + + fn set_up_tool(name: &str, description: &str, params: Value) -> Tool { + Tool { + name: name.to_string(), + description: description.to_string(), + input_schema: json!({ + "properties": params + }), + } + } + + #[test] + fn test_get_usage() { + let provider = set_up_provider(); + let data = json!({ + "usageMetadata": { + "promptTokenCount": 1, + "candidatesTokenCount": 2, + "totalTokenCount": 3 + } + }); + let usage = provider.get_usage(&data).unwrap(); + assert_eq!(usage.input_tokens, Some(1)); + assert_eq!(usage.output_tokens, Some(2)); + assert_eq!(usage.total_tokens, Some(3)); + } + + #[test] + fn test_message_to_google_spec_text_message() { + let provider = set_up_provider(); + let messages = vec![ + set_up_text_message("Hello", Role::User), + set_up_text_message("World", Role::Assistant), + ]; + let payload = provider.messages_to_google_spec(&messages); + assert_eq!(payload.len(), 2); + assert_eq!(payload[0]["role"], "user"); + assert_eq!(payload[0]["parts"][0]["text"], "Hello"); + assert_eq!(payload[1]["role"], "model"); + assert_eq!(payload[1]["parts"][0]["text"], "World"); + } + + #[test] + fn test_message_to_google_spec_tool_request_message() { + let provider = set_up_provider(); + let arguments = json!({ + "param1": "value1" + }); + let messages = vec![set_up_tool_request_message( + "id", + ToolCall::new("tool_name", json!(arguments)), + )]; + let payload = provider.messages_to_google_spec(&messages); + assert_eq!(payload.len(), 1); + assert_eq!(payload[0]["role"], "user"); + assert_eq!(payload[0]["parts"][0]["functionCall"]["args"], arguments); + } + + #[test] + fn test_message_to_google_spec_tool_result_message() { + let provider = set_up_provider(); + let tool_result: AgentResult> = Ok(vec![Content::text("Hello")]); + let messages = vec![set_up_tool_response_message( + "response_id", + tool_result.unwrap(), + )]; + let payload = provider.messages_to_google_spec(&messages); + assert_eq!(payload.len(), 1); + assert_eq!(payload[0]["role"], "model"); + assert_eq!( + payload[0]["parts"][0]["functionResponse"]["name"], + "response_id" + ); + assert_eq!( + payload[0]["parts"][0]["functionResponse"]["response"]["content"]["text"], + "Hello" + ); + } + + #[test] + fn tools_to_google_spec_with_valid_tools() { + let provider = set_up_provider(); + let params1 = json!({ + "param1": { + "type": "string", + "description": "A parameter", + "field_does_not_accept": ["value1", "value2"] + } + }); + let params2 = json!({ + "param2": { + "type": "string", + "description": "B parameter", + } + }); + let tools = vec![ + set_up_tool("tool1", "description1", params1), + set_up_tool("tool2", "description2", params2), + ]; + let result = provider.tools_to_google_spec(&tools); + assert_eq!(result.len(), 2); + assert_eq!(result[0]["name"], "tool1"); + assert_eq!(result[0]["description"], "description1"); + assert_eq!( + result[0]["parameters"]["properties"], + json!({"param1": json!({ + "type": "string", + "description": "A parameter" + })}) + ); + assert_eq!(result[1]["name"], "tool2"); + assert_eq!(result[1]["description"], "description2"); + assert_eq!( + result[1]["parameters"]["properties"], + json!({"param2": json!({ + "type": "string", + "description": "B parameter" + })}) + ); + } + + #[test] + fn tools_to_google_spec_with_empty_properties() { + let provider = set_up_provider(); + let tools = vec![Tool { + name: "tool1".to_string(), + description: "description1".to_string(), + input_schema: json!({ + "properties": {} + }), + }]; + let result = provider.tools_to_google_spec(&tools); + assert_eq!(result.len(), 1); + assert_eq!(result[0]["name"], "tool1"); + assert_eq!(result[0]["description"], "description1"); + assert!(result[0]["parameters"].get("properties").is_none()); + } + + #[test] + fn google_response_to_message_with_no_candidates() { + let provider = set_up_provider(); + let response = json!({}); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert!(message.content.is_empty()); + } + + #[test] + fn google_response_to_message_with_text_part() { + let provider = set_up_provider(); + let response = json!({ + "candidates": [{ + "content": { + "parts": [{ + "text": "Hello, world!" + }] + } + }] + }); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello, world!"); + } else { + panic!("Expected text content"); + } + } + + #[test] + fn google_response_to_message_with_invalid_function_name() { + let provider = set_up_provider(); + let response = json!({ + "candidates": [{ + "content": { + "parts": [{ + "functionCall": { + "name": "invalid name!", + "args": {} + } + }] + } + }] + }); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + if let Err(error) = &message.content[0].as_tool_request().unwrap().tool_call { + assert!(matches!(error, AgentError::ToolNotFound(_))); + } else { + panic!("Expected tool request error"); + } + } + + #[test] + fn google_response_to_message_with_valid_function_call() { + let provider = set_up_provider(); + let response = json!({ + "candidates": [{ + "content": { + "parts": [{ + "functionCall": { + "name": "valid_name", + "args": { + "param": "value" + } + } + }] + } + }] + }); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + if let Ok(tool_call) = &message.content[0].as_tool_request().unwrap().tool_call { + assert_eq!(tool_call.name, "valid_name"); + assert_eq!(tool_call.arguments["param"], "value"); + } else { + panic!("Expected valid tool request"); + } + } +} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 24fbf22ffc46..f3bd5d8ee516 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -620,4 +620,54 @@ mod tests { let result = check_bedrock_context_length_error(&error); assert!(result.is_none()); } + + #[test] + fn unescape_json_values_with_object() { + let value = json!({"text": "Hello\\nWorld"}); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!({"text": "Hello\nWorld"})); + } + + #[test] + fn unescape_json_values_with_array() { + let value = json!(["Hello\\nWorld", "Goodbye\\tWorld"]); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!(["Hello\nWorld", "Goodbye\tWorld"])); + } + + #[test] + fn unescape_json_values_with_string() { + let value = json!("Hello\\nWorld"); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!("Hello\nWorld")); + } + + #[test] + fn unescape_json_values_with_mixed_content() { + let value = json!({ + "text": "Hello\\nWorld\\\\n!", + "array": ["Goodbye\\tWorld", "See you\\rlater"], + "nested": { + "inner_text": "Inner\\\"Quote\\\"" + } + }); + let unescaped_value = unescape_json_values(&value); + assert_eq!( + unescaped_value, + json!({ + "text": "Hello\nWorld\n!", + "array": ["Goodbye\tWorld", "See you\rlater"], + "nested": { + "inner_text": "Inner\"Quote\"" + } + }) + ); + } + + #[test] + fn unescape_json_values_with_no_escapes() { + let value = json!({"text": "Hello World"}); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!({"text": "Hello World"})); + } } From 2c87535c2a6fec23464ded52abb7a9a06d77362c Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 08:59:21 +1100 Subject: [PATCH 16/17] cleanup import --- crates/goose/src/providers/factory.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 832d5fcef7ed..f5a9c0931dfe 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,8 +1,8 @@ use super::{ anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig, - databricks::DatabricksProvider, ollama::OllamaProvider, openai::OpenAiProvider, + databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider, + openai::OpenAiProvider, }; -use crate::providers::google::GoogleProvider; use anyhow::Result; use strum_macros::EnumIter; From 98bc98d2cbd8648a34381ad68b5557a6acc7bc8a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 18 Dec 2024 12:28:53 +1100 Subject: [PATCH 17/17] changed the tokenizer --- crates/goose/src/token_counter.rs | 2 +- download_tokenizer_files.py | 2 +- download_tokenizers.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index ef279465c63b..0a7a5d1127cc 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -13,7 +13,7 @@ pub struct TokenCounter { const GPT_4O_TOKENIZER_KEY: &str = "Xenova--gpt-4o"; const CLAUDE_TOKENIZER_KEY: &str = "Xenova--claude-tokenizer"; -const GOOGLE_TOKENIZER_KEY: &str = "Xenova--gemini-nano"; +const GOOGLE_TOKENIZER_KEY: &str = "Xenova--gemma-2-tokenizer"; const QWEN_TOKENIZER_KEY: &str = "Qwen--Qwen2.5-Coder-32B-Instruct"; impl Default for TokenCounter { diff --git a/download_tokenizer_files.py b/download_tokenizer_files.py index f25b4a2a6980..7c205d45495c 100644 --- a/download_tokenizer_files.py +++ b/download_tokenizer_files.py @@ -16,7 +16,7 @@ "Xenova/gpt-4o", "Xenova/claude-tokenizer", "Qwen/Qwen2.5-Coder-32B-Instruct", - "Xenova/gemini-nano", + "Xenova/gemma-2-tokenizer", ]: download_dir = BASE_DIR / repo_id.replace("/", "--") _path = hf_hub_download(repo_id, filename="tokenizer.json", local_dir=download_dir) diff --git a/download_tokenizers.sh b/download_tokenizers.sh index 4586539381f9..b46618d2d01b 100755 --- a/download_tokenizers.sh +++ b/download_tokenizers.sh @@ -32,4 +32,4 @@ download_tokenizer() { download_tokenizer "Xenova/gpt-4o" download_tokenizer "Xenova/claude-tokenizer" download_tokenizer "Qwen/Qwen2.5-Coder-32B-Instruct" -download_tokenizer "Xenova/gemini-nano" +download_tokenizer "Xenova/gemma-2-tokenizer"