Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub async fn handle_configure(
("databricks", "Databricks", "Models on AI Gateway"),
("ollama", "Ollama", "Local open source models"),
("anthropic", "Anthropic", "Claude models"),
("google", "Google Gemini", "Gemini models"),
])
.interact()?
.to_string()
Expand Down Expand Up @@ -157,6 +158,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str {
"databricks" => "claude-3-5-sonnet-2",
"ollama" => OLLAMA_MODEL,
"anthropic" => "claude-3-5-sonnet-2",
"google" => "gemini-1.5-flash",
_ => panic!("Invalid provider name"),
}
}
Expand All @@ -167,6 +169,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> {
"databricks" => vec!["DATABRICKS_HOST"],
"ollama" => vec!["OLLAMA_HOST"],
"anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint
"google" => vec!["GOOGLE_API_KEY"],
_ => panic!("Invalid provider name"),
}
}
14 changes: 12 additions & 2 deletions crates/goose-cli/src/profile.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::Result;
use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy};
use goose::providers::configs::{
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, ModelConfig,
OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig,
ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand Down Expand Up @@ -125,6 +125,16 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon
model: model_config,
})
}
"google" => {
let api_key = get_keyring_secret("GOOGLE_API_KEY", KeyRetrievalStrategy::Both)
.expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Google(GoogleProviderConfig {
host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint
api_key,
model: model_config,
})
}
_ => panic!("Invalid provider name"),
}
}
Expand Down
36 changes: 35 additions & 1 deletion crates/goose-server/src/configuration.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::error::{to_env_var, ConfigError};
use config::{Config, Environment};
use goose::providers::configs::GoogleProviderConfig;
use goose::providers::{
configs::{
DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig,
OpenAiProviderConfig, ProviderConfig,
},
factory::ProviderType,
ollama,
google, ollama,
utils::ImageFormat,
};
use serde::Deserialize;
Expand Down Expand Up @@ -76,6 +77,17 @@ pub enum ProviderSettings {
#[serde(default)]
estimate_factor: Option<f32>,
},
Google {
#[serde(default = "default_google_host")]
host: String,
api_key: String,
#[serde(default = "default_google_model")]
model: String,
#[serde(default)]
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
},
}

impl ProviderSettings {
Expand All @@ -86,6 +98,7 @@ impl ProviderSettings {
ProviderSettings::OpenAi { .. } => ProviderType::OpenAi,
ProviderSettings::Databricks { .. } => ProviderType::Databricks,
ProviderSettings::Ollama { .. } => ProviderType::Ollama,
ProviderSettings::Google { .. } => ProviderType::Google,
}
}

Expand Down Expand Up @@ -142,6 +155,19 @@ impl ProviderSettings {
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
ProviderSettings::Google {
host,
api_key,
model,
temperature,
max_tokens,
} => ProviderConfig::Google(GoogleProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
.with_max_tokens(max_tokens),
}),
}
}
}
Expand Down Expand Up @@ -233,6 +259,14 @@ fn default_ollama_model() -> String {
ollama::OLLAMA_MODEL.to_string()
}

fn default_google_host() -> String {
google::GOOGLE_API_HOST.to_string()
}

fn default_google_model() -> String {
google::GOOGLE_DEFAULT_MODEL.to_string()
}

fn default_image_format() -> ImageFormat {
ImageFormat::Anthropic
}
Expand Down
7 changes: 7 additions & 0 deletions crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl Clone for AppState {
model: config.model.clone(),
})
}
ProviderConfig::Google(config) => {
ProviderConfig::Google(goose::providers::configs::GoogleProviderConfig {
host: config.host.clone(),
api_key: config.api_key.clone(),
model: config.model.clone(),
})
}
},
agent: self.agent.clone(),
secret_key: self.secret_key.clone(),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ pub mod ollama;
pub mod openai;
pub mod utils;

pub mod google;
#[cfg(test)]
pub mod mock;
14 changes: 14 additions & 0 deletions crates/goose/src/providers/configs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum ProviderConfig {
Databricks(DatabricksProviderConfig),
Ollama(OllamaProviderConfig),
Anthropic(AnthropicProviderConfig),
Google(GoogleProviderConfig),
}

/// Configuration for model-specific settings and limits
Expand Down Expand Up @@ -208,6 +209,19 @@ impl ProviderModelConfig for OpenAiProviderConfig {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoogleProviderConfig {
pub host: String,
pub api_key: String,
pub model: ModelConfig,
}

impl ProviderModelConfig for GoogleProviderConfig {
fn model_config(&self) -> &ModelConfig {
&self.model
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaProviderConfig {
pub host: String,
Expand Down
5 changes: 4 additions & 1 deletion crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{
anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig,
databricks::DatabricksProvider, ollama::OllamaProvider, openai::OpenAiProvider,
databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider,
openai::OpenAiProvider,
};
use anyhow::Result;
use strum_macros::EnumIter;
Expand All @@ -11,6 +12,7 @@ pub enum ProviderType {
Databricks,
Ollama,
Anthropic,
Google,
}

pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send + Sync>> {
Expand All @@ -23,5 +25,6 @@ pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send +
ProviderConfig::Anthropic(anthropic_config) => {
Ok(Box::new(AnthropicProvider::new(anthropic_config)?))
}
ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)),
}
}
Loading
Loading