Merged (23 commits, changes from all commits)
16 changes: 12 additions & 4 deletions crates/goose-cli/src/commands/configure.rs
@@ -5,8 +5,13 @@ use cliclack::spinner;
use console::style;
use goose::key_manager::{get_keyring_secret, save_to_keyring, KeyRetrievalStrategy};
use goose::message::Message;
use goose::providers::anthropic::ANTHROPIC_DEFAULT_MODEL;
use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL;
use goose::providers::factory;
use goose::providers::google::GOOGLE_DEFAULT_MODEL;
use goose::providers::groq::GROQ_DEFAULT_MODEL;
use goose::providers::ollama::OLLAMA_MODEL;
use goose::providers::openai::OPEN_AI_DEFAULT_MODEL;
use std::error::Error;

pub async fn handle_configure(
@@ -48,6 +53,7 @@ pub async fn handle_configure(
("ollama", "Ollama", "Local open source models"),
("anthropic", "Anthropic", "Claude models"),
("google", "Google Gemini", "Gemini models"),
("groq", "Groq", "AI models"),
])
.interact()?
.to_string()
@@ -154,11 +160,12 @@ pub async fn handle_configure(

pub fn get_recommended_model(provider_name: &str) -> &str {
match provider_name {
"openai" => "gpt-4o",
"databricks" => "claude-3-5-sonnet-2",
"openai" => OPEN_AI_DEFAULT_MODEL,
"databricks" => DATABRICKS_DEFAULT_MODEL,
"ollama" => OLLAMA_MODEL,
"anthropic" => "claude-3-5-sonnet-2",
"google" => "gemini-1.5-flash",
"anthropic" => ANTHROPIC_DEFAULT_MODEL,
"google" => GOOGLE_DEFAULT_MODEL,
Review comment (Collaborator): nit: this seems like something we should standardize in the providers for openai, databricks, and anthropic (see the sketch after this function)
Reply (Author): done
"groq" => GROQ_DEFAULT_MODEL,
_ => panic!("Invalid provider name"),
}
}
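
A sketch of the standardization this thread converges on: each provider module exports its default model as a pub const, so callers like get_recommended_model stop hardcoding strings. The Anthropic and Databricks constants appear verbatim later in this diff; the OpenAI value is inferred from the previously hardcoded "gpt-4o" and should be read as an assumption.

// crates/goose/src/providers/openai.rs (value inferred from the old hardcoded default)
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";

// crates/goose/src/providers/anthropic.rs (added verbatim in this PR)
pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest";

// crates/goose/src/providers/databricks.rs (added verbatim in this PR)
pub const DATABRICKS_DEFAULT_MODEL: &str = "claude-3-5-sonnet-2";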
@@ -170,6 +177,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> {
"ollama" => vec!["OLLAMA_HOST"],
"anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint
"google" => vec!["GOOGLE_API_KEY"],
"groq" => vec!["GROQ_API_KEY"],
_ => panic!("Invalid provider name"),
}
}
14 changes: 12 additions & 2 deletions crates/goose-cli/src/profile.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy};
use goose::providers::configs::{
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig,
ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
GroqProviderConfig, ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -130,7 +130,17 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderConfig {
.expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Google(GoogleProviderConfig {
host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint
host: "https://generativelanguage.googleapis.com".to_string(),
api_key,
model: model_config,
})
}
"groq" => {
let api_key = get_keyring_secret("GROQ_API_KEY", KeyRetrievalStrategy::Both)
.expect("GROQ_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Groq(GroqProviderConfig {
host: "https://api.groq.com".to_string(),
api_key,
model: model_config,
})
40 changes: 37 additions & 3 deletions crates/goose-server/src/configuration.rs
@@ -1,13 +1,14 @@
use crate::error::{to_env_var, ConfigError};
use config::{Config, Environment};
use goose::providers::configs::GoogleProviderConfig;
use goose::providers::configs::{GoogleProviderConfig, GroqProviderConfig};
use goose::providers::openai::OPEN_AI_DEFAULT_MODEL;
use goose::providers::{
configs::{
DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig,
OpenAiProviderConfig, ProviderConfig,
},
factory::ProviderType,
google, ollama,
google, groq, ollama,
utils::ImageFormat,
};
use serde::Deserialize;
@@ -88,6 +89,17 @@ pub enum ProviderSettings {
#[serde(default)]
max_tokens: Option<i32>,
},
Groq {
Review comment (Collaborator): one thing I missed in the google provider and here initially was that we also want to add context_limit and estimate_factor here as well (see the sketch after this enum)
#[serde(default = "default_groq_host")]
host: String,
api_key: String,
#[serde(default = "default_groq_model")]
model: String,
#[serde(default)]
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
},
}
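
A sketch of what the reviewer is asking for, written as a standalone toy enum so it compiles on its own. The real ProviderSettings has more variants and its own serde attributes; the default model value below is a stand-in for groq::GROQ_DEFAULT_MODEL, whose actual value this diff does not show, and the two new field types are assumed to match ModelConfig's.

use serde::Deserialize;

fn default_groq_host() -> String {
    "https://api.groq.com".to_string()
}

fn default_groq_model() -> String {
    // Stand-in value; the real default is groq::GROQ_DEFAULT_MODEL.
    "llama3-8b-8192".to_string()
}

#[derive(Debug, Deserialize)]
enum ProviderSettings {
    Groq {
        #[serde(default = "default_groq_host")]
        host: String,
        api_key: String,
        #[serde(default = "default_groq_model")]
        model: String,
        #[serde(default)]
        temperature: Option<f32>,
        #[serde(default)]
        max_tokens: Option<i32>,
        // The two knobs the reviewer asks for, mirroring the other providers:
        #[serde(default)]
        context_limit: Option<usize>,
        #[serde(default)]
        estimate_factor: Option<f32>,
    },
}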

impl ProviderSettings {
@@ -99,6 +111,7 @@ impl ProviderSettings {
ProviderSettings::Databricks { .. } => ProviderType::Databricks,
ProviderSettings::Ollama { .. } => ProviderType::Ollama,
ProviderSettings::Google { .. } => ProviderType::Google,
ProviderSettings::Groq { .. } => ProviderType::Groq,
}
}

@@ -168,6 +181,19 @@ impl ProviderSettings {
.with_temperature(temperature)
.with_max_tokens(max_tokens),
}),
ProviderSettings::Groq {
host,
api_key,
model,
temperature,
max_tokens,
} => ProviderConfig::Groq(GroqProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
Review comment (Collaborator): context_limit and estimate_factor should go here as well for Groq and Google (see the sketch after this impl block)
.with_max_tokens(max_tokens),
}),
}
}
}
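
The matching change in into_config would thread the two new fields through the ModelConfig builder. This fragment of the match assumes ModelConfig exposes with_context_limit and with_estimate_factor builders, as the reviewer's comment implies it already does for the other providers:

ProviderSettings::Groq {
    host,
    api_key,
    model,
    temperature,
    max_tokens,
    context_limit,
    estimate_factor,
} => ProviderConfig::Groq(GroqProviderConfig {
    host,
    api_key,
    model: ModelConfig::new(model)
        .with_context_limit(context_limit)
        .with_estimate_factor(estimate_factor)
        .with_temperature(temperature)
        .with_max_tokens(max_tokens),
}),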
@@ -240,7 +266,7 @@ fn default_port() -> u16 {
}

fn default_model() -> String {
"gpt-4o".to_string()
OPEN_AI_DEFAULT_MODEL.to_string()
}

fn default_openai_host() -> String {
@@ -267,6 +293,14 @@ fn default_google_model() -> String {
google::GOOGLE_DEFAULT_MODEL.to_string()
}

fn default_groq_host() -> String {
groq::GROQ_API_HOST.to_string()
}

fn default_groq_model() -> String {
groq::GROQ_DEFAULT_MODEL.to_string()
}

fn default_image_format() -> ImageFormat {
ImageFormat::Anthropic
}
6 changes: 6 additions & 0 deletions crates/goose-server/src/state.rs
@@ -1,4 +1,5 @@
use anyhow::Result;
use goose::providers::configs::GroqProviderConfig;
use goose::{
agent::Agent,
developer::DeveloperSystem,
@@ -71,6 +72,11 @@ impl Clone for AppState {
model: config.model.clone(),
})
}
ProviderConfig::Groq(config) => ProviderConfig::Groq(GroqProviderConfig {
host: config.host.clone(),
api_key: config.api_key.clone(),
model: config.model.clone(),
}),
},
agent: self.agent.clone(),
secret_key: self.secret_key.clone(),
1 change: 1 addition & 0 deletions crates/goose/build.rs
@@ -8,6 +8,7 @@ const MODELS: &[&str] = &[
"Xenova/gemma-2-tokenizer",
"Xenova/gpt-4o",
"Qwen/Qwen2.5-Coder-32B-Instruct",
"Xenova/llama3-tokenizer",
];

#[tokio::main]
4 changes: 4 additions & 0 deletions crates/goose/src/providers.rs
@@ -7,8 +7,12 @@ pub mod model_pricing;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod openai_utils;
pub mod utils;

pub mod google;
pub mod groq;
#[cfg(test)]
pub mod mock;
#[cfg(test)]
pub mod mock_server;
2 changes: 2 additions & 0 deletions crates/goose/src/providers/anthropic.rs
@@ -17,6 +17,8 @@ use mcp_core::content::Content;
use mcp_core::role::Role;
use mcp_core::tool::{Tool, ToolCall};

pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest";

pub struct AnthropicProvider {
client: Client,
config: AnthropicProviderConfig,
14 changes: 14 additions & 0 deletions crates/goose/src/providers/configs.rs
@@ -14,6 +14,7 @@ pub enum ProviderConfig {
Ollama(OllamaProviderConfig),
Anthropic(AnthropicProviderConfig),
Google(GoogleProviderConfig),
Groq(GroqProviderConfig),
}

/// Configuration for model-specific settings and limits
@@ -222,6 +223,19 @@ impl ProviderModelConfig for GoogleProviderConfig {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqProviderConfig {
pub host: String,
pub api_key: String,
pub model: ModelConfig,
}

impl ProviderModelConfig for GroqProviderConfig {
fn model_config(&self) -> &ModelConfig {
&self.model
}
}
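
A hypothetical wiring sketch showing how the new config flows from construction to the factory, mirroring the profile.rs and factory.rs changes elsewhere in this PR; the model name is a placeholder rather than the crate's GROQ_DEFAULT_MODEL.

use goose::providers::configs::{GroqProviderConfig, ModelConfig, ProviderConfig};
use goose::providers::factory::get_provider;

fn build_groq_provider() -> anyhow::Result<()> {
    // Assemble the provider config the same way profile.rs does.
    let config = ProviderConfig::Groq(GroqProviderConfig {
        host: "https://api.groq.com".to_string(),
        api_key: std::env::var("GROQ_API_KEY")?,
        model: ModelConfig::new("llama3-8b-8192".to_string()), // placeholder model name
    });

    // The factory match added in this PR dispatches to GroqProvider::new.
    let _provider = get_provider(config)?;
    Ok(())
}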

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaProviderConfig {
pub host: String,
76 changes: 18 additions & 58 deletions crates/goose/src/providers/databricks.rs
@@ -1,20 +1,23 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Duration;

use super::base::{Provider, ProviderUsage, Usage};
use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig};
use super::model_pricing::{cost, model_pricing_for};
use super::oauth;
use super::utils::{
check_bedrock_context_length_error, check_openai_context_length_error, get_model,
messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec,
};
use super::utils::{check_bedrock_context_length_error, get_model, handle_response};
use crate::message::Message;
use crate::providers::openai_utils::{
check_openai_context_length_error, get_openai_usage, messages_to_openai_spec,
openai_response_to_message, tools_to_openai_spec,
};
Review comment (Collaborator): nit: we should prob use the get_openai_usage and handle_response in this one as well. (see the sketch after get_usage below)
Reply (Author): done
use mcp_core::tool::Tool;

pub const DATABRICKS_DEFAULT_MODEL: &str = "claude-3-5-sonnet-2";

pub struct DatabricksProvider {
client: Client,
config: DatabricksProviderConfig,
@@ -46,30 +49,7 @@ impl DatabricksProvider {
}

fn get_usage(data: &Value) -> Result<Usage> {
let usage = data
.get("usage")
.ok_or_else(|| anyhow!("No usage data in response"))?;

let input_tokens = usage
.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|v| v as i32);

let output_tokens = usage
.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|v| v as i32);

let total_tokens = usage
.get("total_tokens")
.and_then(|v| v.as_i64())
.map(|v| v as i32)
.or_else(|| match (input_tokens, output_tokens) {
(Some(input), Some(output)) => Some(input + output),
_ => None,
});

Ok(Usage::new(input_tokens, output_tokens, total_tokens))
get_openai_usage(data)
}
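
For reference, the shared helper that get_usage now delegates to presumably carries the same logic as the inline block deleted above. A sketch reconstructed from those deleted lines; the real definition lives in crate::providers::openai_utils:

use anyhow::{anyhow, Result};
use serde_json::Value;

use super::base::Usage; // assumed import path within the providers module

// Extracts OpenAI-style token accounting from a response body; the body below
// is the inline code this PR deletes from DatabricksProvider::get_usage.
pub fn get_openai_usage(data: &Value) -> Result<Usage> {
    let usage = data
        .get("usage")
        .ok_or_else(|| anyhow!("No usage data in response"))?;

    let input_tokens = usage
        .get("prompt_tokens")
        .and_then(|v| v.as_i64())
        .map(|v| v as i32);

    let output_tokens = usage
        .get("completion_tokens")
        .and_then(|v| v.as_i64())
        .map(|v| v as i32);

    // Fall back to summing the two counts when the API omits total_tokens.
    let total_tokens = usage
        .get("total_tokens")
        .and_then(|v| v.as_i64())
        .map(|v| v as i32)
        .or_else(|| match (input_tokens, output_tokens) {
            (Some(input), Some(output)) => Some(input + output),
            _ => None,
        });

    Ok(Usage::new(input_tokens, output_tokens, total_tokens))
}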

async fn post(&self, payload: Value) -> Result<Value> {
@@ -88,18 +68,7 @@ impl DatabricksProvider {
.send()
.await?;

match response.status() {
StatusCode::OK => Ok(response.json().await?),
status if status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() => {
// Implement retry logic here if needed
Err(anyhow!("Server error: {}", status))
}
_ => {
let status = response.status();
let err_text = response.text().await.unwrap_or_default();
Err(anyhow!("Request failed: {}: {}", status, err_text))
}
}
handle_response(payload, response).await?
}
}

@@ -112,7 +81,7 @@ impl Provider for DatabricksProvider {
tools: &[Tool],
) -> Result<(Message, ProviderUsage)> {
// Prepare messages and tools
let messages_spec = messages_to_openai_spec(messages, &self.config.image_format);
let messages_spec = messages_to_openai_spec(messages, &self.config.image_format, false);
let tools_spec = if !tools.is_empty() {
tools_to_openai_spec(tools)?
} else {
@@ -179,6 +148,9 @@ mod tests {
use super::*;
use crate::message::MessageContent;
use crate::providers::configs::ModelConfig;
use crate::providers::mock_server::{
create_mock_open_ai_response, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOTAL_TOKENS,
};
use wiremock::matchers::{body_json, header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

@@ -188,19 +160,7 @@
let mock_server = MockServer::start().await;

// Mock response for completion
let mock_response = json!({
"choices": [{
"message": {
"role": "assistant",
"content": "Hello!"
}
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 25,
"total_tokens": 35
}
});
let mock_response = create_mock_open_ai_response("my-databricks-model", "Hello!");

// Expected request body
let system = "You are a helpful assistant.";
@@ -244,9 +204,9 @@
} else {
panic!("Expected Text content");
}
assert_eq!(reply_usage.usage.input_tokens, Some(10));
assert_eq!(reply_usage.usage.output_tokens, Some(25));
assert_eq!(reply_usage.usage.total_tokens, Some(35));
assert_eq!(reply_usage.usage.input_tokens, Some(TEST_INPUT_TOKENS));
assert_eq!(reply_usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS));
assert_eq!(reply_usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS));

Ok(())
}
6 changes: 4 additions & 2 deletions crates/goose/src/providers/factory.rs
@@ -1,7 +1,7 @@
use super::{
anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig,
databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider,
openai::OpenAiProvider,
databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider,
ollama::OllamaProvider, openai::OpenAiProvider,
};
use anyhow::Result;
use strum_macros::EnumIter;
@@ -13,6 +13,7 @@ pub enum ProviderType {
Ollama,
Anthropic,
Google,
Groq,
}

pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send + Sync>> {
@@ -26,5 +27,6 @@ pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send + Sync>> {
Ok(Box::new(AnthropicProvider::new(anthropic_config)?))
}
ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)),
ProviderConfig::Groq(groq_config) => Ok(Box::new(GroqProvider::new(groq_config)?)),
}
}