diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index c611e0595601..1308513ccf40 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -47,6 +47,7 @@ pub async fn handle_configure( ("databricks", "Databricks", "Models on AI Gateway"), ("ollama", "Ollama", "Local open source models"), ("anthropic", "Anthropic", "Claude models"), + ("google", "Google Gemini", "Gemini models"), ]) .interact()? .to_string() @@ -157,6 +158,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str { "databricks" => "claude-3-5-sonnet-2", "ollama" => OLLAMA_MODEL, "anthropic" => "claude-3-5-sonnet-2", + "google" => "gemini-1.5-flash", _ => panic!("Invalid provider name"), } } @@ -167,6 +169,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> { "databricks" => vec!["DATABRICKS_HOST"], "ollama" => vec!["OLLAMA_HOST"], "anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint + "google" => vec!["GOOGLE_API_KEY"], _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index 429932b0e054..6e03f6b387cc 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -1,8 +1,8 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; use goose::providers::configs::{ - AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, ModelConfig, - OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, + AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig, + ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -125,6 +125,16 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon model: model_config, }) } + "google" => { + let api_key = 
get_keyring_secret("GOOGLE_API_KEY", KeyRetrievalStrategy::Both) + .expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); + + ProviderConfig::Google(GoogleProviderConfig { + host: "https://generativelanguage.googleapis.com".to_string(), // Default Google Gemini API endpoint + api_key, + model: model_config, + }) + } _ => panic!("Invalid provider name"), } } diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index 047cb379a36e..de47633013a1 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -1,12 +1,13 @@ use crate::error::{to_env_var, ConfigError}; use config::{Config, Environment}; +use goose::providers::configs::GoogleProviderConfig; use goose::providers::{ configs::{ DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, }, factory::ProviderType, - ollama, + google, ollama, utils::ImageFormat, }; use serde::Deserialize; @@ -76,6 +77,17 @@ pub enum ProviderSettings { #[serde(default)] estimate_factor: Option, }, + Google { + #[serde(default = "default_google_host")] + host: String, + api_key: String, + #[serde(default = "default_google_model")] + model: String, + #[serde(default)] + temperature: Option, + #[serde(default)] + max_tokens: Option, + }, } impl ProviderSettings { @@ -86,6 +98,7 @@ impl ProviderSettings { ProviderSettings::OpenAi { .. } => ProviderType::OpenAi, ProviderSettings::Databricks { .. } => ProviderType::Databricks, ProviderSettings::Ollama { .. } => ProviderType::Ollama, + ProviderSettings::Google { .. 
} => ProviderType::Google, } } @@ -142,6 +155,19 @@ impl ProviderSettings { .with_context_limit(context_limit) .with_estimate_factor(estimate_factor), }), + ProviderSettings::Google { + host, + api_key, + model, + temperature, + max_tokens, + } => ProviderConfig::Google(GoogleProviderConfig { + host, + api_key, + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens), + }), } } } @@ -233,6 +259,14 @@ fn default_ollama_model() -> String { ollama::OLLAMA_MODEL.to_string() } +fn default_google_host() -> String { + google::GOOGLE_API_HOST.to_string() +} + +fn default_google_model() -> String { + google::GOOGLE_DEFAULT_MODEL.to_string() +} + fn default_image_format() -> ImageFormat { ImageFormat::Anthropic } diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 9ab997715f88..446c538dcee4 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -64,6 +64,13 @@ impl Clone for AppState { model: config.model.clone(), }) } + ProviderConfig::Google(config) => { + ProviderConfig::Google(goose::providers::configs::GoogleProviderConfig { + host: config.host.clone(), + api_key: config.api_key.clone(), + model: config.model.clone(), + }) + } }, agent: self.agent.clone(), secret_key: self.secret_key.clone(), diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index e60eb851f1d1..f2d7758aec67 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -9,5 +9,6 @@ pub mod ollama; pub mod openai; pub mod utils; +pub mod google; #[cfg(test)] pub mod mock; diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 346892924810..67c49282dc5f 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -13,6 +13,7 @@ pub enum ProviderConfig { Databricks(DatabricksProviderConfig), Ollama(OllamaProviderConfig), Anthropic(AnthropicProviderConfig), + 
Google(GoogleProviderConfig), } /// Configuration for model-specific settings and limits @@ -208,6 +209,19 @@ impl ProviderModelConfig for OpenAiProviderConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GoogleProviderConfig { + pub host: String, + pub api_key: String, + pub model: ModelConfig, +} + +impl ProviderModelConfig for GoogleProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaProviderConfig { pub host: String, diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 46f5b3ff8382..f5a9c0931dfe 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,6 +1,7 @@ use super::{ anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig, - databricks::DatabricksProvider, ollama::OllamaProvider, openai::OpenAiProvider, + databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider, + openai::OpenAiProvider, }; use anyhow::Result; use strum_macros::EnumIter; @@ -11,6 +12,7 @@ pub enum ProviderType { Databricks, Ollama, Anthropic, + Google, } pub fn get_provider(config: ProviderConfig) -> Result> { @@ -23,5 +25,6 @@ pub fn get_provider(config: ProviderConfig) -> Result { Ok(Box::new(AnthropicProvider::new(anthropic_config)?)) } + ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)), } } diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs new file mode 100644 index 000000000000..e5ab052de328 --- /dev/null +++ b/crates/goose/src/providers/google.rs @@ -0,0 +1,620 @@ +use crate::errors::AgentError; +use crate::message::{Message, MessageContent}; +use crate::providers::base::{Provider, ProviderUsage, Usage}; +use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; +use crate::providers::utils::{ + is_valid_function_name, 
sanitize_function_name, unescape_json_values, +}; +use anyhow::anyhow; +use async_trait::async_trait; +use mcp_core::{Content, Role, Tool, ToolCall}; +use reqwest::{Client, StatusCode}; +use serde_json::{json, Map, Value}; +use std::time::Duration; + +pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com"; +pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-1.5-flash"; + +pub struct GoogleProvider { + client: Client, + config: GoogleProviderConfig, +} + +impl GoogleProvider { + pub fn new(config: GoogleProviderConfig) -> anyhow::Result { + let client = Client::builder() + .timeout(Duration::from_secs(600)) // 10 minutes timeout + .build()?; + + Ok(Self { client, config }) + } + + fn get_usage(&self, data: &Value) -> anyhow::Result { + if let Some(usage_meta_data) = data.get("usageMetadata") { + let input_tokens = usage_meta_data + .get("promptTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let output_tokens = usage_meta_data + .get("candidatesTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + let total_tokens = usage_meta_data + .get("totalTokenCount") + .and_then(|v| v.as_u64()) + .map(|v| v as i32); + Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + } else { + // If no usage data, return None for all values + Ok(Usage::new(None, None, None)) + } + } + + async fn post(&self, payload: Value) -> anyhow::Result { + let url = format!( + "{}/v1beta/models/{}:generateContent?key={}", + self.config.host.trim_end_matches('/'), + self.config.model.model_name, + self.config.api_key + ); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await?; + + match response.status() { + StatusCode::OK => Ok(response.json().await?), + status if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() >= 500 => { + // Implement retry logic here if needed + Err(anyhow!("Server error: {}", status)) + } + _ => Err(anyhow!( + "Request failed: {}\nPayload: 
{}", + response.status(), + payload + )), + } + } + + fn messages_to_google_spec(&self, messages: &[Message]) -> Vec { + messages + .iter() + .map(|message| { + let role = if message.role == Role::User { + "user" + } else { + "model" + }; + let mut parts = Vec::new(); + for message_content in message.content.iter() { + match message_content { + MessageContent::Text(text) => { + if !text.text.is_empty() { + parts.push(json!({"text": text.text})); + } + } + MessageContent::ToolRequest(request) => match &request.tool_call { + Ok(tool_call) => { + let mut function_call_part = Map::new(); + function_call_part.insert( + "name".to_string(), + json!(sanitize_function_name(&tool_call.name)), + ); + if tool_call.arguments.is_object() + && !tool_call.arguments.as_object().unwrap().is_empty() + { + function_call_part + .insert("args".to_string(), tool_call.arguments.clone()); + } + parts.push(json!({ + "functionCall": function_call_part + })); + } + Err(e) => { + parts.push(json!({"text":format!("Error: {}", e)})); + } + }, + MessageContent::ToolResponse(response) => { + match &response.tool_result { + Ok(contents) => { + // Send only contents with no audience or with Assistant in the audience + let abridged: Vec<_> = contents + .iter() + .filter(|content| { + content.audience().is_none_or(|audience| { + audience.contains(&Role::Assistant) + }) + }) + .map(|content| content.unannotated()) + .collect(); + + for content in abridged { + match content { + Content::Image(image) => { + parts.push(json!({ + "inline_data": { + "mime_type": image.mime_type, + "data": image.data, + } + })); + } + _ => { + parts.push(json!({ + "functionResponse": { + "name": response.id, + "response": {"content": content}, + }} + )); + } + } + } + } + Err(e) => { + parts.push(json!({"text":format!("Error: {}", e)})); + } + } + } + + _ => {} + } + } + json!({"role": role, "parts": parts}) + }) + .collect() + } + + fn tools_to_google_spec(&self, tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| { + 
let mut parameters = Map::new(); + parameters.insert("name".to_string(), json!(tool.name)); + parameters.insert("description".to_string(), json!(tool.description)); + let tool_input_schema = tool.input_schema.as_object().unwrap(); + let tool_input_schema_properties = tool_input_schema + .get("properties") + .unwrap_or(&json!({})) + .as_object() + .unwrap() + .clone(); + if !tool_input_schema_properties.is_empty() { + let accepted_tool_schema_attributes = vec![ + "type".to_string(), + "format".to_string(), + "description".to_string(), + "nullable".to_string(), + "enum".to_string(), + "maxItems".to_string(), + "minItems".to_string(), + "properties".to_string(), + "required".to_string(), + "items".to_string(), + ]; + parameters.insert( + "parameters".to_string(), + json!(self.process_map( + tool_input_schema, + &accepted_tool_schema_attributes, + None + )), + ); + } + json!(parameters) + }) + .collect() + } + + fn process_map( + &self, + map: &Map, + accepted_keys: &[String], + parent_key: Option<&str>, // Track the parent key + ) -> Value { + let mut filtered_map: Map = map + .iter() + .filter_map(|(key, value)| { + let should_remove = + !accepted_keys.contains(key) && parent_key != Some("properties"); + if should_remove { + return None; + } + // Process nested maps recursively + let filtered_value = match value { + Value::Object(nested_map) => self.process_map( + &nested_map + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + accepted_keys, + Some(key), + ), + _ => value.clone(), + }; + + Some((key.clone(), filtered_value)) + }) + .collect(); + if parent_key != Some("properties") && !filtered_map.contains_key("type") { + filtered_map.insert("type".to_string(), Value::String("string".to_string())); + } + + Value::Object(filtered_map) + } + + fn google_response_to_message(&self, response: Value) -> anyhow::Result { + let mut content = Vec::new(); + let binding = vec![]; + let candidates: &Vec = response + .get("candidates") + .and_then(|v| v.as_array()) 
+ .unwrap_or(&binding); + let candidate = candidates.get(0); + let role = Role::Assistant; + let created = chrono::Utc::now().timestamp(); + if candidate.is_none() { + return Ok(Message { + role, + created, + content, + }); + } + let candidate = candidate.unwrap(); + let parts = candidate + .get("content") + .and_then(|content| content.get("parts")) + .and_then(|parts| parts.as_array()) + .unwrap_or(&binding); + for part in parts { + if let Some(text) = part.get("text").and_then(|v| v.as_str()) { + content.push(MessageContent::text(text.to_string())); + } else if let Some(function_call) = part.get("functionCall") { + let id = function_call["name"] + .as_str() + .unwrap_or_default() + .to_string(); + let name = function_call["name"] + .as_str() + .unwrap_or_default() + .to_string(); + if !is_valid_function_name(&name) { + let error = AgentError::ToolNotFound(format!( + "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", + name + )); + content.push(MessageContent::tool_request(id, Err(error))); + } else { + let parameters = function_call.get("args"); + if parameters.is_some() { + content.push(MessageContent::tool_request( + id, + Ok(ToolCall::new(&name, parameters.unwrap().clone())), + )); + } + } + } + } + Ok(Message { + role, + created, + content, + }) + } +} + +#[async_trait] +impl Provider for GoogleProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> anyhow::Result<(Message, ProviderUsage)> { + let mut payload = Map::new(); + payload.insert( + "system_instruction".to_string(), + json!({"parts": [{"text": system}]}), + ); + payload.insert( + "contents".to_string(), + json!(self.messages_to_google_spec(&messages)), + ); + if !tools.is_empty() { + payload.insert( + "tools".to_string(), + json!({"functionDeclarations": self.tools_to_google_spec(&tools)}), + ); + } + let mut 
generation_config = Map::new(); + if let Some(temp) = self.config.model.temperature { + generation_config.insert("temperature".to_string(), json!(temp)); + } + if let Some(tokens) = self.config.model.max_tokens { + generation_config.insert("maxOutputTokens".to_string(), json!(tokens)); + } + if !generation_config.is_empty() { + payload.insert("generationConfig".to_string(), json!(generation_config)); + } + + // Make request + let response = self.post(Value::Object(payload)).await?; + // Parse response + let message = self.google_response_to_message(unescape_json_values(&response))?; + let usage = self.get_usage(&response)?; + let model = match response.get("modelVersion") { + Some(model_version) => model_version.as_str().unwrap_or_default().to_string(), + None => self.config.model.model_name.clone(), + }; + let provider_usage = ProviderUsage::new(model, usage, None); + Ok((message, provider_usage)) + } +} + +#[cfg(test)] // Only compiles this module when running tests +mod tests { + use super::*; + use crate::errors::AgentResult; + fn set_up_provider() -> GoogleProvider { + let provider_config = GoogleProviderConfig { + host: "dummy_host".to_string(), + api_key: "dummy_key".to_string(), + model: ModelConfig::new("dummy_model".to_string()), + }; + GoogleProvider::new(provider_config).unwrap() + } + + fn set_up_text_message(text: &str, role: Role) -> Message { + Message { + role, + created: 0, + content: vec![MessageContent::text(text.to_string())], + } + } + + fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message { + Message { + role: Role::User, + created: 0, + content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], + } + } + + fn set_up_tool_response_message(id: &str, tool_response: Vec) -> Message { + Message { + role: Role::Assistant, + created: 0, + content: vec![MessageContent::tool_response( + id.to_string(), + Ok(tool_response), + )], + } + } + + fn set_up_tool(name: &str, description: &str, params: Value) -> Tool { + 
Tool { + name: name.to_string(), + description: description.to_string(), + input_schema: json!({ + "properties": params + }), + } + } + + #[test] + fn test_get_usage() { + let provider = set_up_provider(); + let data = json!({ + "usageMetadata": { + "promptTokenCount": 1, + "candidatesTokenCount": 2, + "totalTokenCount": 3 + } + }); + let usage = provider.get_usage(&data).unwrap(); + assert_eq!(usage.input_tokens, Some(1)); + assert_eq!(usage.output_tokens, Some(2)); + assert_eq!(usage.total_tokens, Some(3)); + } + + #[test] + fn test_message_to_google_spec_text_message() { + let provider = set_up_provider(); + let messages = vec![ + set_up_text_message("Hello", Role::User), + set_up_text_message("World", Role::Assistant), + ]; + let payload = provider.messages_to_google_spec(&messages); + assert_eq!(payload.len(), 2); + assert_eq!(payload[0]["role"], "user"); + assert_eq!(payload[0]["parts"][0]["text"], "Hello"); + assert_eq!(payload[1]["role"], "model"); + assert_eq!(payload[1]["parts"][0]["text"], "World"); + } + + #[test] + fn test_message_to_google_spec_tool_request_message() { + let provider = set_up_provider(); + let arguments = json!({ + "param1": "value1" + }); + let messages = vec![set_up_tool_request_message( + "id", + ToolCall::new("tool_name", json!(arguments)), + )]; + let payload = provider.messages_to_google_spec(&messages); + assert_eq!(payload.len(), 1); + assert_eq!(payload[0]["role"], "user"); + assert_eq!(payload[0]["parts"][0]["functionCall"]["args"], arguments); + } + + #[test] + fn test_message_to_google_spec_tool_result_message() { + let provider = set_up_provider(); + let tool_result: AgentResult> = Ok(vec![Content::text("Hello")]); + let messages = vec![set_up_tool_response_message( + "response_id", + tool_result.unwrap(), + )]; + let payload = provider.messages_to_google_spec(&messages); + assert_eq!(payload.len(), 1); + assert_eq!(payload[0]["role"], "model"); + assert_eq!( + payload[0]["parts"][0]["functionResponse"]["name"], + 
"response_id" + ); + assert_eq!( + payload[0]["parts"][0]["functionResponse"]["response"]["content"]["text"], + "Hello" + ); + } + + #[test] + fn tools_to_google_spec_with_valid_tools() { + let provider = set_up_provider(); + let params1 = json!({ + "param1": { + "type": "string", + "description": "A parameter", + "field_does_not_accept": ["value1", "value2"] + } + }); + let params2 = json!({ + "param2": { + "type": "string", + "description": "B parameter", + } + }); + let tools = vec![ + set_up_tool("tool1", "description1", params1), + set_up_tool("tool2", "description2", params2), + ]; + let result = provider.tools_to_google_spec(&tools); + assert_eq!(result.len(), 2); + assert_eq!(result[0]["name"], "tool1"); + assert_eq!(result[0]["description"], "description1"); + assert_eq!( + result[0]["parameters"]["properties"], + json!({"param1": json!({ + "type": "string", + "description": "A parameter" + })}) + ); + assert_eq!(result[1]["name"], "tool2"); + assert_eq!(result[1]["description"], "description2"); + assert_eq!( + result[1]["parameters"]["properties"], + json!({"param2": json!({ + "type": "string", + "description": "B parameter" + })}) + ); + } + + #[test] + fn tools_to_google_spec_with_empty_properties() { + let provider = set_up_provider(); + let tools = vec![Tool { + name: "tool1".to_string(), + description: "description1".to_string(), + input_schema: json!({ + "properties": {} + }), + }]; + let result = provider.tools_to_google_spec(&tools); + assert_eq!(result.len(), 1); + assert_eq!(result[0]["name"], "tool1"); + assert_eq!(result[0]["description"], "description1"); + assert!(result[0]["parameters"].get("properties").is_none()); + } + + #[test] + fn google_response_to_message_with_no_candidates() { + let provider = set_up_provider(); + let response = json!({}); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert!(message.content.is_empty()); + } + + #[test] + fn 
google_response_to_message_with_text_part() { + let provider = set_up_provider(); + let response = json!({ + "candidates": [{ + "content": { + "parts": [{ + "text": "Hello, world!" + }] + } + }] + }); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello, world!"); + } else { + panic!("Expected text content"); + } + } + + #[test] + fn google_response_to_message_with_invalid_function_name() { + let provider = set_up_provider(); + let response = json!({ + "candidates": [{ + "content": { + "parts": [{ + "functionCall": { + "name": "invalid name!", + "args": {} + } + }] + } + }] + }); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + if let Err(error) = &message.content[0].as_tool_request().unwrap().tool_call { + assert!(matches!(error, AgentError::ToolNotFound(_))); + } else { + panic!("Expected tool request error"); + } + } + + #[test] + fn google_response_to_message_with_valid_function_call() { + let provider = set_up_provider(); + let response = json!({ + "candidates": [{ + "content": { + "parts": [{ + "functionCall": { + "name": "valid_name", + "args": { + "param": "value" + } + } + }] + } + }] + }); + let message = provider.google_response_to_message(response).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + if let Ok(tool_call) = &message.content[0].as_tool_request().unwrap().tool_call { + assert_eq!(tool_call.name, "valid_name"); + assert_eq!(tool_call.arguments["param"], "value"); + } else { + panic!("Expected valid tool request"); + } + } +} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index f300dd041405..f3bd5d8ee516 100644 --- a/crates/goose/src/providers/utils.rs +++ 
b/crates/goose/src/providers/utils.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use regex::Regex; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use crate::errors::AgentError; use crate::message::{Message, MessageContent}; @@ -234,12 +234,12 @@ pub fn openai_response_to_message(response: Value) -> Result { }) } -fn sanitize_function_name(name: &str) -> String { +pub fn sanitize_function_name(name: &str) -> String { let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap(); re.replace_all(name, "_").to_string() } -fn is_valid_function_name(name: &str) -> bool { +pub fn is_valid_function_name(name: &str) -> bool { let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap(); re.is_match(name) } @@ -287,6 +287,35 @@ pub fn get_model(data: &Value) -> String { } } +pub fn unescape_json_values(value: &Value) -> Value { + match value { + Value::Object(map) => { + let new_map: Map = map + .iter() + .map(|(k, v)| (k.clone(), unescape_json_values(v))) // Process each value + .collect(); + Value::Object(new_map) + } + Value::Array(arr) => { + let new_array: Vec = arr.iter().map(|v| unescape_json_values(v)).collect(); + Value::Array(new_array) + } + Value::String(s) => { + let unescaped = s + .replace("\\\\n", "\n") + .replace("\\\\t", "\t") + .replace("\\\\r", "\r") + .replace("\\\\\"", "\"") + .replace("\\n", "\n") + .replace("\\t", "\t") + .replace("\\r", "\r") + .replace("\\\"", "\""); + Value::String(unescaped) + } + _ => value.clone(), + } +} + #[cfg(test)] mod tests { use super::*; @@ -591,4 +620,54 @@ mod tests { let result = check_bedrock_context_length_error(&error); assert!(result.is_none()); } + + #[test] + fn unescape_json_values_with_object() { + let value = json!({"text": "Hello\\nWorld"}); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!({"text": "Hello\nWorld"})); + } + + #[test] + fn unescape_json_values_with_array() { + let value = json!(["Hello\\nWorld", "Goodbye\\tWorld"]); 
+ let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!(["Hello\nWorld", "Goodbye\tWorld"])); + } + + #[test] + fn unescape_json_values_with_string() { + let value = json!("Hello\\nWorld"); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!("Hello\nWorld")); + } + + #[test] + fn unescape_json_values_with_mixed_content() { + let value = json!({ + "text": "Hello\\nWorld\\\\n!", + "array": ["Goodbye\\tWorld", "See you\\rlater"], + "nested": { + "inner_text": "Inner\\\"Quote\\\"" + } + }); + let unescaped_value = unescape_json_values(&value); + assert_eq!( + unescaped_value, + json!({ + "text": "Hello\nWorld\n!", + "array": ["Goodbye\tWorld", "See you\rlater"], + "nested": { + "inner_text": "Inner\"Quote\"" + } + }) + ); + } + + #[test] + fn unescape_json_values_with_no_escapes() { + let value = json!({"text": "Hello World"}); + let unescaped_value = unescape_json_values(&value); + assert_eq!(unescaped_value, json!({"text": "Hello World"})); + } } diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 40d1ab09a0ed..0a7a5d1127cc 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -13,6 +13,7 @@ pub struct TokenCounter { const GPT_4O_TOKENIZER_KEY: &str = "Xenova--gpt-4o"; const CLAUDE_TOKENIZER_KEY: &str = "Xenova--claude-tokenizer"; +const GOOGLE_TOKENIZER_KEY: &str = "Xenova--gemma-2-tokenizer"; const QWEN_TOKENIZER_KEY: &str = "Qwen--Qwen2.5-Coder-32B-Instruct"; impl Default for TokenCounter { @@ -48,7 +49,11 @@ impl TokenCounter { tokenizers: HashMap::new(), }; // Add default tokenizers - for tokenizer_key in [GPT_4O_TOKENIZER_KEY, CLAUDE_TOKENIZER_KEY] { + for tokenizer_key in [ + GPT_4O_TOKENIZER_KEY, + CLAUDE_TOKENIZER_KEY, + GOOGLE_TOKENIZER_KEY, + ] { counter.load_tokenizer(tokenizer_key); } counter @@ -64,6 +69,8 @@ impl TokenCounter { CLAUDE_TOKENIZER_KEY } else if model_name.contains("qwen") { QWEN_TOKENIZER_KEY + } 
else if model_name.contains("gemini") { + GOOGLE_TOKENIZER_KEY } else { // default GPT_4O_TOKENIZER_KEY diff --git a/download_tokenizer_files.py b/download_tokenizer_files.py index 19437d10caae..7c205d45495c 100644 --- a/download_tokenizer_files.py +++ b/download_tokenizer_files.py @@ -16,6 +16,7 @@ "Xenova/gpt-4o", "Xenova/claude-tokenizer", "Qwen/Qwen2.5-Coder-32B-Instruct", + "Xenova/gemma-2-tokenizer", ]: download_dir = BASE_DIR / repo_id.replace("/", "--") _path = hf_hub_download(repo_id, filename="tokenizer.json", local_dir=download_dir) diff --git a/download_tokenizers.sh b/download_tokenizers.sh index 6658194d4eda..b46618d2d01b 100755 --- a/download_tokenizers.sh +++ b/download_tokenizers.sh @@ -31,4 +31,5 @@ download_tokenizer() { # Download tokenizers for each model download_tokenizer "Xenova/gpt-4o" download_tokenizer "Xenova/claude-tokenizer" -download_tokenizer "Qwen/Qwen2.5-Coder-32B-Instruct" \ No newline at end of file +download_tokenizer "Qwen/Qwen2.5-Coder-32B-Instruct" +download_tokenizer "Xenova/gemma-2-tokenizer"