diff --git a/crates/goose/src/providers/formats/gcpvertexai.rs b/crates/goose/src/providers/formats/gcpvertexai.rs
index 8d60d6018593..8b2372cb8df1 100644
--- a/crates/goose/src/providers/formats/gcpvertexai.rs
+++ b/crates/goose/src/providers/formats/gcpvertexai.rs
@@ -8,6 +8,18 @@
 use serde_json::Value;
 use std::fmt;
 
+pub type StreamingMessageStream = std::pin::Pin<
+    Box<
+        dyn futures::Stream<
+                Item = anyhow::Result<(
+                    Option<Message>,
+                    Option<ProviderUsage>,
+                )>,
+            > + Send
+            + 'static,
+    >,
+>;
+
 /// Sensible default values of Google Cloud Platform (GCP) locations for model deployment.
 ///
 /// Each variant corresponds to a specific GCP region where models can be hosted.
@@ -367,6 +379,21 @@ pub fn get_usage(data: &Value, request_context: &RequestContext) -> Result<Usage> {
 
+pub fn response_to_streaming_message<S>(
+    stream: S,
+    request_context: &RequestContext,
+) -> StreamingMessageStream
+where
+    S: futures::Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
+{
+    match request_context.provider() {
+        ModelProvider::Anthropic => Box::pin(anthropic::response_to_streaming_message(stream)),
+        ModelProvider::Google | ModelProvider::MaaS(_) => {
+            Box::pin(google::response_to_streaming_message(stream))
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs
index a1d45c556e24..93e18f4d4b1c 100644
--- a/crates/goose/src/providers/gcpvertexai.rs
+++ b/crates/goose/src/providers/gcpvertexai.rs
@@ -1,21 +1,26 @@
+use std::io;
 use std::time::Duration;
 
 use anyhow::Result;
+use async_stream::try_stream;
 use async_trait::async_trait;
+use futures::StreamExt;
+use futures::TryStreamExt;
 use once_cell::sync::Lazy;
 use reqwest::{Client, StatusCode};
 use serde_json::Value;
 use tokio::time::sleep;
+use tokio_util::io::StreamReader;
 use url::Url;
 
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
+use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage};
 use crate::providers::errors::ProviderError;
 use crate::providers::formats::gcpvertexai::{
-    create_request, get_usage, response_to_message, ClaudeVersion, GcpVertexAIModel, GeminiVersion,
-    ModelProvider, RequestContext,
+    create_request, get_usage, response_to_message, response_to_streaming_message, ClaudeVersion,
+    GcpVertexAIModel, GeminiVersion, ModelProvider, RequestContext,
 };
 use crate::providers::formats::gcpvertexai::GcpLocation::Iowa;
 
@@ -40,6 +45,66 @@ const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;
 
 static STATUS_API_OVERLOADED: Lazy<StatusCode> =
     Lazy::new(|| StatusCode::from_u16(529).expect("Valid status code 529 for API_OVERLOADED"));
 
+fn rate_limit_error_message(response_text: &str) -> String {
+    let cite = "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
+    if response_text.contains("Exceeded the Provisioned Throughput") {
+        format!("Exceeded the Provisioned Throughput: {cite}")
+    } else {
+        format!("Pay-as-you-go resource exhausted: {cite}")
+    }
+}
+
+const OVERLOADED_ERROR_MSG: &str =
+    "Vertex AI Provider API is temporarily overloaded. This is similar to a rate limit \
+     error but indicates backend processing capacity issues.";
+
+fn build_vertex_url(
+    host: &str,
+    configured_location: &str,
+    project_id: &str,
+    model_name: &str,
+    provider: ModelProvider,
+    target_location: &str,
+    streaming: bool,
+) -> Result<Url, GcpVertexAIError> {
+    let host_url = if configured_location == target_location {
+        host.to_string()
+    } else {
+        host.replace(configured_location, target_location)
+    };
+
+    let base_url =
+        Url::parse(&host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
+
+    let endpoint = match (&provider, streaming) {
+        (ModelProvider::Anthropic, true) => "streamRawPredict",
+        (ModelProvider::Anthropic, false) => "rawPredict",
+        (ModelProvider::Google, true) => "streamGenerateContent",
+        (ModelProvider::Google, false) => "generateContent",
+        (ModelProvider::MaaS(_), true) => "streamGenerateContent",
+        (ModelProvider::MaaS(_), false) => "generateContent",
+    };
+
+    let path = format!(
+        "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
+        project_id,
+        target_location,
+        provider.as_str(),
+        model_name,
+        endpoint
+    );
+
+    let mut url = base_url
+        .join(&path)
+        .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
+
+    if streaming && !matches!(provider, ModelProvider::Anthropic) {
+        url.set_query(Some("alt=sse"));
+    }
+
+    Ok(url)
+}
+
 /// Represents errors specific to GCP Vertex AI operations.
 #[derive(Debug, thiserror::Error)]
 enum GcpVertexAIError {
@@ -172,94 +237,48 @@ impl GcpVertexAIProvider {
             .map_err(|e| GcpVertexAIError::AuthError(e.to_string()))
     }
 
-    /// Constructs the appropriate API endpoint URL for a given provider.
-    ///
-    /// # Arguments
-    /// * `provider` - The model provider (Anthropic or Google)
-    /// * `location` - The GCP location for model deployment
     fn build_request_url(
         &self,
         provider: ModelProvider,
         location: &str,
+        streaming: bool,
     ) -> Result<Url, GcpVertexAIError> {
-        // Create host URL for the specified location
-        let host_url = if self.location == location {
-            &self.host
-        } else {
-            // Only allocate a new string if location differs
-            &self.host.replace(&self.location, location)
-        };
-
-        let base_url =
-            Url::parse(host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
-
-        // Determine endpoint based on provider type
-        let endpoint = match provider {
-            ModelProvider::Anthropic => "streamRawPredict",
-            ModelProvider::Google => "generateContent",
-            ModelProvider::MaaS(_) => "generateContent",
-        };
-
-        // Construct path for URL
-        let path = format!(
-            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
-            self.project_id,
+        build_vertex_url(
+            &self.host,
+            &self.location,
+            &self.project_id,
+            &self.model.model_name,
+            provider,
             location,
-            provider.as_str(),
-            self.model.model_name,
-            endpoint
-        );
-
-        base_url
-            .join(&path)
-            .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))
+            streaming,
+        )
     }
 
-    /// Makes an authenticated POST request to the Vertex AI API at a specific location.
-    /// Includes retry logic for 429 (Too Many Requests) and 529 (API Overloaded) errors.
-    ///
-    /// # Arguments
-    /// * `payload` - The request payload to send
-    /// * `context` - Request context containing model information
-    /// * `location` - The GCP location for the request
-    async fn post_with_location(
+    async fn send_request_with_retry(
         &self,
+        url: Url,
         payload: &Value,
-        context: &RequestContext,
-        location: &str,
-    ) -> Result<Value, ProviderError> {
-        let url = self
-            .build_request_url(context.provider(), location)
-            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
-
-        // Initialize separate counters for different error types
+    ) -> Result<reqwest::Response, ProviderError> {
         let mut rate_limit_attempts = 0;
         let mut overloaded_attempts = 0;
         let mut last_error = None;
+        let max_retries = self.retry_config.max_retries;
 
         loop {
-            // Check if we've exceeded max retries
-            if rate_limit_attempts > self.retry_config.max_retries
-                && overloaded_attempts > self.retry_config.max_retries
-            {
-                let error_msg = format!(
-                    "Exceeded maximum retry attempts ({}) for rate limiting errors",
-                    self.retry_config.max_retries
+            if rate_limit_attempts > max_retries && overloaded_attempts > max_retries {
+                return Err(
+                    last_error.unwrap_or_else(|| ProviderError::RateLimitExceeded {
+                        details: format!("Exceeded maximum retry attempts ({max_retries})"),
+                        retry_delay: None,
+                    }),
                 );
-                tracing::error!("{}", error_msg);
-                return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
-                    details: error_msg,
-                    retry_delay: None,
-                }));
             }
 
-            // Get a fresh auth token for each attempt
             let auth_header = self
                 .get_auth_header()
                 .await
                 .map_err(|e| ProviderError::Authentication(e.to_string()))?;
 
-            // Make the request
             let response = self
                 .client
                 .post(url.clone())
@@ -271,162 +290,143 @@ impl GcpVertexAIProvider {
 
             let status = response.status();
 
-            // Handle 429 Too Many Requests and 529 API Overloaded errors
-            match status {
-                status if status == StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limit_attempts += 1;
-
-                    if rate_limit_attempts > self.retry_config.max_retries {
-                        let error_msg = format!(
-                            "Exceeded maximum retry attempts ({}) for rate limiting (429) errors",
-                            self.retry_config.max_retries
-                        );
-                        tracing::error!("{}", error_msg);
-                        return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
-                            details: error_msg,
+            if status == StatusCode::TOO_MANY_REQUESTS {
+                rate_limit_attempts += 1;
+                if rate_limit_attempts > max_retries {
+                    return Err(
+                        last_error.unwrap_or_else(|| ProviderError::RateLimitExceeded {
+                            details: format!("Exceeded max retries ({max_retries}) for 429"),
                             retry_delay: None,
-                        }));
-                    }
-
-                    // Try to parse response for more detailed error info
-                    let cite_gcp_vertex_429 =
-                        "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
-                    let response_text = response.text().await.unwrap_or_default();
-
-                    let error_message =
-                        if response_text.contains("Exceeded the Provisioned Throughput") {
-                            // Handle 429 rate limit due to throughput limits
-                            format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}")
-                        } else {
-                            // Handle generic 429 rate limit
-                            format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}")
-                        };
-
-                    tracing::warn!(
-                        "Rate limit exceeded error (429) (attempt {}/{}): {}. Retrying after backoff...",
-                        rate_limit_attempts,
-                        self.retry_config.max_retries,
-                        error_message
+                        }),
                     );
-
-                    // Store the error in case we need to return it after max retries
-                    last_error = Some(ProviderError::RateLimitExceeded {
-                        details: error_message,
-                        retry_delay: None,
-                    });
-
-                    // Calculate and apply the backoff delay
-                    let delay = self.retry_config.delay_for_attempt(rate_limit_attempts);
-                    tracing::info!("Backing off for {:?} before retry (rate limit 429)", delay);
-                    sleep(delay).await;
                 }
-                status if status == *STATUS_API_OVERLOADED => {
-                    overloaded_attempts += 1;
-
-                    if overloaded_attempts > self.retry_config.max_retries {
-                        let error_msg = format!(
-                            "Exceeded maximum retry attempts ({}) for API overloaded (529) errors",
-                            self.retry_config.max_retries
-                        );
-                        tracing::error!("{}", error_msg);
-                        return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
-                            details: error_msg,
+                let msg = rate_limit_error_message(&response.text().await.unwrap_or_default());
+                tracing::warn!("429 (attempt {rate_limit_attempts}/{max_retries}): {msg}");
+                last_error = Some(ProviderError::RateLimitExceeded {
+                    details: msg,
+                    retry_delay: None,
+                });
+                sleep(self.retry_config.delay_for_attempt(rate_limit_attempts)).await;
+            } else if status == *STATUS_API_OVERLOADED {
+                overloaded_attempts += 1;
+                if overloaded_attempts > max_retries {
+                    return Err(
+                        last_error.unwrap_or_else(|| ProviderError::RateLimitExceeded {
+                            details: format!("Exceeded max retries ({max_retries}) for 529"),
                             retry_delay: None,
-                        }));
-                    }
-
-                    // Handle 529 Overloaded error (https://docs.anthropic.com/en/api/errors)
-                    let error_message =
-                        "Vertex AI Provider API is temporarily overloaded. This is similar to a rate limit \
-                        error but indicates backend processing capacity issues."
-                            .to_string();
-
-                    tracing::warn!(
-                        "API overloaded error (529) (attempt {}/{}): {}. Retrying after backoff...",
-                        overloaded_attempts,
-                        self.retry_config.max_retries,
-                        error_message
-                    );
-
-                    // Store the error in case we need to return it after max retries
-                    last_error = Some(ProviderError::RateLimitExceeded {
-                        details: error_message,
-                        retry_delay: None,
-                    });
-
-                    // Calculate and apply the backoff delay
-                    let delay = self.retry_config.delay_for_attempt(overloaded_attempts);
-                    tracing::info!(
-                        "Backing off for {:?} before retry (API overloaded 529)",
-                        delay
+                        }),
                     );
-                    sleep(delay).await;
-                }
-                // For any other status codes, process normally
-                _ => {
-                    let response_json = response.json::<Value>().await.map_err(|e| {
-                        ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
-                    })?;
-
-                    return match status {
-                        StatusCode::OK => Ok(response_json),
-                        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
-                            tracing::debug!(
-                                "Authentication failed. Status: {status}, Payload: {payload:?}"
-                            );
-                            Err(ProviderError::Authentication(format!(
-                                "Authentication failed: {response_json:?}"
-                            )))
-                        }
-                        _ => {
-                            tracing::debug!(
-                                "Request failed. Status: {status}, Response: {response_json:?}"
-                            );
-                            Err(ProviderError::RequestFailed(format!(
-                                "Request failed with status {status}: {response_json:?}"
-                            )))
-                        }
-                    };
                 }
+                tracing::warn!(
+                    "529 (attempt {overloaded_attempts}/{max_retries}): {OVERLOADED_ERROR_MSG}"
+                );
+                last_error = Some(ProviderError::RateLimitExceeded {
+                    details: OVERLOADED_ERROR_MSG.to_string(),
+                    retry_delay: None,
+                });
+                sleep(self.retry_config.delay_for_attempt(overloaded_attempts)).await;
+            } else if status == StatusCode::OK {
+                return Ok(response);
+            } else if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
+                return Err(ProviderError::Authentication(format!(
+                    "Authentication failed with status: {status}"
+                )));
+            } else {
+                let response_text = response.text().await.unwrap_or_default();
+                return Err(ProviderError::RequestFailed(format!(
+                    "Request failed with status {status}: {response_text}"
+                )));
             }
         }
     }
 
-    /// Makes an authenticated POST request to the Vertex AI API with fallback for invalid locations.
-    ///
-    /// # Arguments
-    /// * `payload` - The request payload to send
-    /// * `context` - Request context containing model information
+    async fn post_with_location(
+        &self,
+        payload: &Value,
+        context: &RequestContext,
+        location: &str,
+    ) -> Result<Value, ProviderError> {
+        let url = self
+            .build_request_url(context.provider(), location, false)
+            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
+
+        let response = self.send_request_with_retry(url, payload).await?;
+
+        response
+            .json::<Value>()
+            .await
+            .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))
+    }
+
     async fn post(
         &self,
         payload: &Value,
         context: &RequestContext,
     ) -> Result<Value, ProviderError> {
-        // Try with user-specified location first
         let result = self
             .post_with_location(payload, context, &self.location)
             .await;
 
-        // If location is already the known location for the model or request succeeded, return result
         if self.location == context.model.known_location().to_string() || result.is_ok() {
             return result;
         }
 
-        // Check if we should retry with the model's known location
         match &result {
             Err(ProviderError::RequestFailed(msg)) => {
                 let model_name = context.model.to_string();
                 let configured_location = &self.location;
                 let known_location = context.model.known_location().to_string();
 
-                tracing::error!(
+                tracing::warn!(
                     "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
                 );
 
                 self.post_with_location(payload, context, &known_location)
                     .await
             }
-            // For any other error, return the original result
             _ => result,
         }
     }
+
+    async fn post_stream_with_location(
+        &self,
+        payload: &Value,
+        context: &RequestContext,
+        location: &str,
+    ) -> Result<reqwest::Response, ProviderError> {
+        let url = self
+            .build_request_url(context.provider(), location, true)
+            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
+
+        self.send_request_with_retry(url, payload).await
+    }
+
+    async fn post_stream(
+        &self,
+        payload: &Value,
+        context: &RequestContext,
+    ) -> Result<reqwest::Response, ProviderError> {
+        let result = self
+            .post_stream_with_location(payload, context, &self.location)
+            .await;
+
+        if self.location == context.model.known_location().to_string() || result.is_ok() {
+            return result;
+        }
+
+        match &result {
+            Err(ProviderError::RequestFailed(msg)) => {
+                let model_name = context.model.to_string();
+                let configured_location = &self.location;
+                let known_location = context.model.known_location().to_string();
+
+                tracing::warn!(
+                    "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
+                );
+
+                self.post_stream_with_location(payload, context, &known_location)
+                    .await
+            }
             _ => result,
         }
     }
 }
 
@@ -538,6 +538,56 @@ impl Provider for GcpVertexAIProvider {
     fn get_model_config(&self) -> ModelConfig {
         self.model.clone()
     }
+
+    fn supports_streaming(&self) -> bool {
+        true
+    }
+
+    async fn stream(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<MessageStream, ProviderError> {
+        let model_config = self.get_model_config();
+        let (mut request, context) = create_request(&model_config, system, messages, tools)?;
+
+        if matches!(context.provider(), ModelProvider::Anthropic) {
+            if let Some(obj) = request.as_object_mut() {
+                obj.insert("stream".to_string(), Value::Bool(true));
+            }
+        }
+
+        let mut log = RequestLog::start(&model_config, &request)?;
+
+        let response = self
+            .post_stream(&request, &context)
+            .await
+            .inspect_err(|e| {
+                let _ = log.error(e);
+            })?;
+
+        let stream = response.bytes_stream().map_err(io::Error::other);
+
+        let context_clone = context.clone();
+        Ok(Box::pin(try_stream! {
+            let stream_reader = StreamReader::new(stream);
+            let framed = tokio_util::codec::FramedRead::new(
+                stream_reader,
+                tokio_util::codec::LinesCodec::new(),
+            )
+            .map_err(anyhow::Error::from);
+
+            let mut message_stream = response_to_streaming_message(framed, &context_clone);
+
+            while let Some(message) = message_stream.next().await {
+                let (message, usage) = message
+                    .map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
+                log.write(&message, usage.as_ref().map(|u| &u.usage))?;
+                yield (message, usage);
+            }
+        }))
+    }
 }
 
 #[cfg(test)]
@@ -594,31 +644,63 @@ mod tests {
     }
 
     #[test]
-    fn test_url_construction() {
-        use url::Url;
-
-        let model_config = ModelConfig::new_or_fail("claude-sonnet-4-20250514");
-        let context = RequestContext::new(&model_config.model_name).unwrap();
-        let api_model_id = context.model.to_string();
-
-        let host = "https://us-east5-aiplatform.googleapis.com";
-        let project_id = "test-project";
-        let location = "us-east5";
-
-        let path = format!(
-            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
-            project_id,
-            location,
-            ModelProvider::Anthropic.as_str(),
-            api_model_id,
-            "streamRawPredict"
-        );
+    fn test_build_vertex_url_endpoints() {
+        let anthropic_url = build_vertex_url(
+            "https://us-east5-aiplatform.googleapis.com",
+            "us-east5",
+            "test-project",
+            "claude-sonnet-4@20250514",
+            ModelProvider::Anthropic,
+            "us-east5",
+            false,
+        )
+        .unwrap();
+        assert!(anthropic_url.as_str().contains(":rawPredict"));
+
+        let anthropic_stream = build_vertex_url(
+            "https://us-east5-aiplatform.googleapis.com",
+            "us-east5",
+            "test-project",
+            "claude-sonnet-4@20250514",
+            ModelProvider::Anthropic,
+            "us-east5",
+            true,
+        )
+        .unwrap();
+        assert!(anthropic_stream.as_str().contains(":streamRawPredict"));
+        assert!(anthropic_stream.query().is_none());
+
+        let google_stream = build_vertex_url(
+            "https://us-central1-aiplatform.googleapis.com",
+            "us-central1",
+            "test-project",
+            "gemini-2.5-flash",
+            ModelProvider::Google,
+            "us-central1",
+            true,
+        )
+        .unwrap();
+        assert!(google_stream.as_str().contains(":streamGenerateContent"));
+        assert_eq!(google_stream.query(), Some("alt=sse"));
+    }
 
-        let url = Url::parse(host).unwrap().join(&path).unwrap();
+    #[test]
+    fn test_build_vertex_url_location_replacement() {
+        let url = build_vertex_url(
+            "https://us-east5-aiplatform.googleapis.com",
+            "us-east5",
+            "test-project",
+            "claude-sonnet-4@20250514",
+            ModelProvider::Anthropic,
+            "europe-west1",
+            false,
+        )
+        .unwrap();
 
-        assert!(url.as_str().contains("publishers/anthropic"));
-        assert!(url.as_str().contains("projects/test-project"));
-        assert!(url.as_str().contains("locations/us-east5"));
+        assert!(url
+            .as_str()
+            .contains("europe-west1-aiplatform.googleapis.com"));
+        assert!(url.as_str().contains("locations/europe-west1"));
     }
 
     #[test]