feat: added groq provider #494
Changes from all commits: 3b8ce1d, eef6f3c, 4715cb9, 07147eb, 0a91706, ab705fd, 8ce3bdd, 4a51019, 726471d, 38ecbdc, d9d184d, c025802, e3485bf, 152ea97, aa6b678, d66e48a, 13d8967, 5cd0c06, ab5279e, 7f4a90f, 529f5d1, 2b2dad7, 0674f30
The first changed file wires the Groq provider into the provider settings and their defaults:

```diff
@@ -1,13 +1,14 @@
 use crate::error::{to_env_var, ConfigError};
 use config::{Config, Environment};
-use goose::providers::configs::GoogleProviderConfig;
+use goose::providers::configs::{GoogleProviderConfig, GroqProviderConfig};
+use goose::providers::openai::OPEN_AI_DEFAULT_MODEL;
 use goose::providers::{
     configs::{
         DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig,
         OpenAiProviderConfig, ProviderConfig,
     },
     factory::ProviderType,
-    google, ollama,
+    google, groq, ollama,
     utils::ImageFormat,
 };
 use serde::Deserialize;
```
@@ -88,6 +89,17 @@ pub enum ProviderSettings { | |
| #[serde(default)] | ||
| max_tokens: Option<i32>, | ||
| }, | ||
| Groq { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one thing I missed in the google provider and here initially was that we also want to add context_limit and estimate_factor here as well |
||
| #[serde(default = "default_groq_host")] | ||
| host: String, | ||
| api_key: String, | ||
| #[serde(default = "default_groq_model")] | ||
| model: String, | ||
| #[serde(default)] | ||
| temperature: Option<f32>, | ||
| #[serde(default)] | ||
| max_tokens: Option<i32>, | ||
| }, | ||
| } | ||
|
|
||
| impl ProviderSettings { | ||
|
|
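As a rough illustration of that request, the variant could carry the two extra fields alongside the existing ones. This is only a sketch; the field types and serde attributes are assumptions patterned on the fields already in the diff, not code from the PR.

```rust
// Sketch only (not from the PR): Groq settings variant with the two fields
// the reviewer asks for, following the same serde pattern as the others.
Groq {
    #[serde(default = "default_groq_host")]
    host: String,
    api_key: String,
    #[serde(default = "default_groq_model")]
    model: String,
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    max_tokens: Option<i32>,
    #[serde(default)]
    context_limit: Option<usize>,  // assumed type
    #[serde(default)]
    estimate_factor: Option<f32>,  // assumed type
},
```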
@@ -99,6 +111,7 @@ impl ProviderSettings { | |
| ProviderSettings::Databricks { .. } => ProviderType::Databricks, | ||
| ProviderSettings::Ollama { .. } => ProviderType::Ollama, | ||
| ProviderSettings::Google { .. } => ProviderType::Google, | ||
| ProviderSettings::Groq { .. } => ProviderType::Groq, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -168,6 +181,19 @@ impl ProviderSettings { | |
| .with_temperature(temperature) | ||
| .with_max_tokens(max_tokens), | ||
| }), | ||
| ProviderSettings::Groq { | ||
| host, | ||
| api_key, | ||
| model, | ||
| temperature, | ||
| max_tokens, | ||
| } => ProviderConfig::Groq(GroqProviderConfig { | ||
| host, | ||
| api_key, | ||
| model: ModelConfig::new(model) | ||
| .with_temperature(temperature) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. context_limit and estimate_factor should go here as well for Groq and Google |
||
| .with_max_tokens(max_tokens), | ||
| }), | ||
| } | ||
| } | ||
| } | ||
|
|
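A sketch of how those settings might flow into the provider config, assuming `ModelConfig` exposes builder methods analogous to `with_temperature`/`with_max_tokens`; the `with_context_limit` and `with_estimate_factor` names are assumptions, not confirmed by this diff.

```rust
// Sketch only: threading context_limit and estimate_factor through to
// ModelConfig. Builder method names are assumed by analogy with the
// with_temperature / with_max_tokens calls shown in the diff.
ProviderSettings::Groq {
    host,
    api_key,
    model,
    temperature,
    context_limit,
    estimate_factor,
    max_tokens,
} => ProviderConfig::Groq(GroqProviderConfig {
    host,
    api_key,
    model: ModelConfig::new(model)
        .with_context_limit(context_limit)
        .with_temperature(temperature)
        .with_max_tokens(max_tokens)
        .with_estimate_factor(estimate_factor),
}),
```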
@@ -240,7 +266,7 @@ fn default_port() -> u16 { | |
| } | ||
|
|
||
| fn default_model() -> String { | ||
| "gpt-4o".to_string() | ||
| OPEN_AI_DEFAULT_MODEL.to_string() | ||
| } | ||
|
|
||
| fn default_openai_host() -> String { | ||
|
|
@@ -267,6 +293,14 @@ fn default_google_model() -> String { | |
| google::GOOGLE_DEFAULT_MODEL.to_string() | ||
| } | ||
|
|
||
| fn default_groq_host() -> String { | ||
| groq::GROQ_API_HOST.to_string() | ||
| } | ||
|
|
||
| fn default_groq_model() -> String { | ||
| groq::GROQ_DEFAULT_MODEL.to_string() | ||
| } | ||
|
|
||
| fn default_image_format() -> ImageFormat { | ||
| ImageFormat::Anthropic | ||
| } | ||
|
|
||
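The new `goose::providers::groq` module and `GroqProviderConfig` are not part of this capture. Based only on how they are referenced above, they would need to export something like the following; the constant values, derives, and the `ModelConfig` import path are placeholders, not taken from the PR.

```rust
// Hypothetical sketch of the items referenced by the settings code above.
use super::configs::ModelConfig; // path assumed

pub const GROQ_API_HOST: &str = "https://api.groq.com"; // illustrative value
pub const GROQ_DEFAULT_MODEL: &str = "llama3-70b-8192"; // illustrative value

#[derive(Debug, Clone)]
pub struct GroqProviderConfig {
    pub host: String,
    pub api_key: String,
    pub model: ModelConfig,
}
```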
The second changed file is the Databricks provider, refactored to use the shared OpenAI-format helpers:

```diff
@@ -1,20 +1,23 @@
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
-use reqwest::{Client, StatusCode};
+use reqwest::Client;
 use serde_json::{json, Value};
 use std::time::Duration;
 
 use super::base::{Provider, ProviderUsage, Usage};
 use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig};
 use super::model_pricing::{cost, model_pricing_for};
 use super::oauth;
-use super::utils::{
-    check_bedrock_context_length_error, check_openai_context_length_error, get_model,
-    messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec,
-};
+use super::utils::{check_bedrock_context_length_error, get_model, handle_response};
 use crate::message::Message;
+use crate::providers::openai_utils::{
+    check_openai_context_length_error, get_openai_usage, messages_to_openai_spec,
+    openai_response_to_message, tools_to_openai_spec,
+};
```
Collaborator: nit: we should prob use the `get_openai_usage` and `handle_response` in this one as well.

Author: done
```diff
 use mcp_core::tool::Tool;
 
 pub const DATABRICKS_DEFAULT_MODEL: &str = "claude-3-5-sonnet-2";
 
 pub struct DatabricksProvider {
     client: Client,
     config: DatabricksProviderConfig,
```
@@ -46,30 +49,7 @@ impl DatabricksProvider { | |
| } | ||
|
|
||
| fn get_usage(data: &Value) -> Result<Usage> { | ||
| let usage = data | ||
| .get("usage") | ||
| .ok_or_else(|| anyhow!("No usage data in response"))?; | ||
|
|
||
| let input_tokens = usage | ||
| .get("prompt_tokens") | ||
| .and_then(|v| v.as_i64()) | ||
| .map(|v| v as i32); | ||
|
|
||
| let output_tokens = usage | ||
| .get("completion_tokens") | ||
| .and_then(|v| v.as_i64()) | ||
| .map(|v| v as i32); | ||
|
|
||
| let total_tokens = usage | ||
| .get("total_tokens") | ||
| .and_then(|v| v.as_i64()) | ||
| .map(|v| v as i32) | ||
| .or_else(|| match (input_tokens, output_tokens) { | ||
| (Some(input), Some(output)) => Some(input + output), | ||
| _ => None, | ||
| }); | ||
|
|
||
| Ok(Usage::new(input_tokens, output_tokens, total_tokens)) | ||
| get_openai_usage(data) | ||
| } | ||
|
|
||
| async fn post(&self, payload: Value) -> Result<Value> { | ||
|
|
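The shared `get_openai_usage` helper itself is not shown in this capture. Judging from the body deleted above, a sketch of what it presumably centralizes in `openai_utils` looks roughly like this (signature and module paths assumed):

```rust
use anyhow::{anyhow, Result};
use serde_json::Value;

use crate::providers::base::Usage; // path assumed

// Sketch: the token-accounting logic removed from the Databricks provider,
// presumed to live in providers::openai_utils for reuse across providers.
pub fn get_openai_usage(data: &Value) -> Result<Usage> {
    let usage = data
        .get("usage")
        .ok_or_else(|| anyhow!("No usage data in response"))?;

    let input_tokens = usage
        .get("prompt_tokens")
        .and_then(|v| v.as_i64())
        .map(|v| v as i32);

    let output_tokens = usage
        .get("completion_tokens")
        .and_then(|v| v.as_i64())
        .map(|v| v as i32);

    // Fall back to input + output when the response omits total_tokens.
    let total_tokens = usage
        .get("total_tokens")
        .and_then(|v| v.as_i64())
        .map(|v| v as i32)
        .or_else(|| match (input_tokens, output_tokens) {
            (Some(input), Some(output)) => Some(input + output),
            _ => None,
        });

    Ok(Usage::new(input_tokens, output_tokens, total_tokens))
}
```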
@@ -88,18 +68,7 @@ impl DatabricksProvider { | |
| .send() | ||
| .await?; | ||
|
|
||
| match response.status() { | ||
| StatusCode::OK => Ok(response.json().await?), | ||
| status if status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() => { | ||
| // Implement retry logic here if needed | ||
| Err(anyhow!("Server error: {}", status)) | ||
| } | ||
| _ => { | ||
| let status = response.status(); | ||
| let err_text = response.text().await.unwrap_or_default(); | ||
| Err(anyhow!("Request failed: {}: {}", status, err_text)) | ||
| } | ||
| } | ||
| handle_response(payload, response).await? | ||
| } | ||
| } | ||
|
|
||
|
|
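`handle_response` is also defined outside this capture. Judging from the call site (`handle_response(payload, response).await?` as the tail expression of a function returning `Result<Value>`), it appears to return a nested `Result`; a rough sketch consistent with the error handling it replaces might look like this. The signature, the nested return type, and the unused `payload` parameter are assumptions.

```rust
use anyhow::{anyhow, Result};
use reqwest::{Response, StatusCode};
use serde_json::Value;

// Rough sketch only: mirrors the match block deleted above. The nested Result
// lets the caller's `?` unwrap the outer layer while still surfacing API
// errors as the inner value. The real helper presumably uses `payload`
// (e.g. for context-length error checks), which this sketch does not.
pub async fn handle_response(_payload: Value, response: Response) -> Result<Result<Value>> {
    match response.status() {
        StatusCode::OK => Ok(Ok(response.json().await?)),
        status if status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() => {
            Ok(Err(anyhow!("Server error: {}", status)))
        }
        _ => {
            let status = response.status();
            let err_text = response.text().await.unwrap_or_default();
            Ok(Err(anyhow!("Request failed: {}: {}", status, err_text)))
        }
    }
}
```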
@@ -112,7 +81,7 @@ impl Provider for DatabricksProvider { | |
| tools: &[Tool], | ||
| ) -> Result<(Message, ProviderUsage)> { | ||
| // Prepare messages and tools | ||
| let messages_spec = messages_to_openai_spec(messages, &self.config.image_format); | ||
| let messages_spec = messages_to_openai_spec(messages, &self.config.image_format, false); | ||
| let tools_spec = if !tools.is_empty() { | ||
| tools_to_openai_spec(tools)? | ||
| } else { | ||
|
|
@@ -179,6 +148,9 @@ mod tests { | |
| use super::*; | ||
| use crate::message::MessageContent; | ||
| use crate::providers::configs::ModelConfig; | ||
| use crate::providers::mock_server::{ | ||
| create_mock_open_ai_response, TEST_INPUT_TOKENS, TEST_OUTPUT_TOKENS, TEST_TOTAL_TOKENS, | ||
| }; | ||
| use wiremock::matchers::{body_json, header, method, path}; | ||
| use wiremock::{Mock, MockServer, ResponseTemplate}; | ||
|
|
||
|
|
@@ -188,19 +160,7 @@ mod tests { | |
| let mock_server = MockServer::start().await; | ||
|
|
||
| // Mock response for completion | ||
| let mock_response = json!({ | ||
| "choices": [{ | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": "Hello!" | ||
| } | ||
| }], | ||
| "usage": { | ||
| "prompt_tokens": 10, | ||
| "completion_tokens": 25, | ||
| "total_tokens": 35 | ||
| } | ||
| }); | ||
| let mock_response = create_mock_open_ai_response("my-databricks-model", "Hello!"); | ||
|
|
||
| // Expected request body | ||
| let system = "You are a helpful assistant."; | ||
|
|
@@ -244,9 +204,9 @@ mod tests { | |
| } else { | ||
| panic!("Expected Text content"); | ||
| } | ||
| assert_eq!(reply_usage.usage.input_tokens, Some(10)); | ||
| assert_eq!(reply_usage.usage.output_tokens, Some(25)); | ||
| assert_eq!(reply_usage.usage.total_tokens, Some(35)); | ||
| assert_eq!(reply_usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); | ||
| assert_eq!(reply_usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); | ||
| assert_eq!(reply_usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
Collaborator: nit: this seems like something we should standardize in the providers for openai, databricks, and anthropic

Author: done
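For reference, a hedged sketch of what the shared mock helper being discussed might look like, reconstructed from the inline JSON it replaces and the `TEST_*` constants used in the assertions. The actual `mock_server` module is not shown in this diff; the constant values are assumed to match the old hard-coded 10/25/35 counts, and the use of the `model` argument is an assumption.

```rust
use serde_json::{json, Value};

// Assumed to match the token counts from the old inline mock (10 / 25 / 35).
pub const TEST_INPUT_TOKENS: i32 = 10;
pub const TEST_OUTPUT_TOKENS: i32 = 25;
pub const TEST_TOTAL_TOKENS: i32 = 35;

// Sketch of a shared helper that builds an OpenAI-style chat completion
// response for a given model name and assistant message.
pub fn create_mock_open_ai_response(model: &str, content: &str) -> Value {
    json!({
        "model": model,
        "choices": [{
            "message": {
                "role": "assistant",
                "content": content
            }
        }],
        "usage": {
            "prompt_tokens": TEST_INPUT_TOKENS,
            "completion_tokens": TEST_OUTPUT_TOKENS,
            "total_tokens": TEST_TOTAL_TOKENS
        }
    })
}
```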