diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 149499f5bde1..3716df0e6dc3 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -151,4 +151,59 @@ impl Provider for GroqProvider { super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) } + + /// Fetch supported models from Groq; returns Err on failure, Ok(None) if no models found + async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { + // Construct the Groq models endpoint + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {}", e)))?; + let url = base_url.join("openai/v1/models").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {}", e)) + })?; + + // Build the request with required headers + let request = self + .client + .get(url) + .bearer_auth(&self.api_key) + .header("Content-Type", "application/json"); + + // Send request + let response = request.send().await?; + let status = response.status(); + let payload: serde_json::Value = response.json().await.map_err(|_| { + ProviderError::RequestFailed("Response body is not valid JSON".to_string()) + })?; + + // Check for error response from API + if let Some(err_obj) = payload.get("error") { + let msg = err_obj + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + return Err(ProviderError::Authentication(msg.to_string())); + } + + // Extract model names + if status == StatusCode::OK { + let data = payload + .get("data") + .and_then(|v| v.as_array()) + .ok_or_else(|| { + ProviderError::UsageError("Missing or invalid `data` field in response".into()) + })?; + + let mut model_names: Vec = data + .iter() + .filter_map(|m| m.get("id").and_then(Value::as_str).map(String::from)) + .collect(); + model_names.sort(); + Ok(Some(model_names)) + } else { + Err(ProviderError::RequestFailed(format!( + "Groq API returned error status: {}. Payload: {:?}", + status, payload + ))) + } + } }