Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions crates/goose/src/providers/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::providers::utils::{
};
use anyhow::Result;
use async_trait::async_trait;
use axum::http::HeaderMap;
use mcp_core::tool::Tool;
use reqwest::Client;
use serde_json::Value;
Expand Down Expand Up @@ -51,7 +52,6 @@ pub struct GoogleProvider {
#[serde(skip)]
client: Client,
host: String,
api_key: String,
model: ModelConfig,
}

Expand All @@ -70,14 +70,18 @@ impl GoogleProvider {
.get_param("GOOGLE_HOST")
.unwrap_or_else(|_| GOOGLE_API_HOST.to_string());

let mut headers = HeaderMap::new();
headers.insert("CONTENT_TYPE", "application/json".parse()?);
headers.insert("x-goog-api-key", api_key.parse()?);

let client = Client::builder()
.timeout(Duration::from_secs(600))
.default_headers(headers)
.build()?;

Ok(Self {
client,
host,
api_key,
model,
})
}
Expand All @@ -88,8 +92,8 @@ impl GoogleProvider {

let url = base_url
.join(&format!(
"v1beta/models/{}:generateContent?key={}",
self.model.model_name, self.api_key
"v1beta/models/{}:generateContent",
self.model.model_name
))
.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
Expand All @@ -103,7 +107,6 @@ impl GoogleProvider {
let response = self
.client
.post(url.clone()) // Clone the URL for each retry
.header("CONTENT_TYPE", "application/json")
.json(&payload)
.send()
.await;
Expand Down Expand Up @@ -192,7 +195,7 @@ impl Provider for GoogleProvider {
/// Fetch supported models from Google Generative Language API; returns Err on failure, Ok(None) if not present
async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
// List models via the v1beta/models endpoint
let url = format!("{}/v1beta/models?key={}", self.host, self.api_key);
let url = format!("{}/v1beta/models", self.host);
let response = self.client.get(&url).send().await?;
let json: serde_json::Value = response.json().await?;
// If 'models' field missing, return None
Expand Down
Loading