diff --git a/Cargo.lock b/Cargo.lock index 02967400a288..2322aa5c6d6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -704,6 +704,7 @@ dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", + "aws-smithy-eventstream", "aws-smithy-http 0.60.12", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -767,6 +768,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws-sdk-sagemakerruntime" +version = "1.63.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3188bb9f962a9e1781c917dbe7f016ab9430e4bd81ba7daf422e58d86a3595" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http 0.61.1", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-sso" version = "1.61.0" @@ -841,6 +865,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" dependencies = [ "aws-credential-types", + "aws-smithy-eventstream", "aws-smithy-http 0.60.12", "aws-smithy-runtime-api", "aws-smithy-types", @@ -3394,6 +3419,7 @@ dependencies = [ "async-trait", "aws-config", "aws-sdk-bedrockruntime", + "aws-sdk-sagemakerruntime", "aws-smithy-types", "axum", "base64 0.21.7", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 4aa3f12975da..bf2a23d96387 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -68,6 +68,9 @@ aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } aws-smithy-types = "1.2.13" aws-sdk-bedrockruntime = "1.74.0" +# For SageMaker TGI provider +aws-sdk-sagemakerruntime = "1.62.0" + # For GCP Vertex AI provider auth jsonwebtoken = "9.3.1" diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 22bdaa95ddef..af4fc2e34f99 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -14,6 +14,7 @@ use super::{ ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, + sagemaker_tgi::SageMakerTgiProvider, snowflake::SnowflakeProvider, venice::VeniceProvider, }; @@ -48,6 +49,7 @@ pub fn providers() -> Vec { OllamaProvider::metadata(), OpenAiProvider::metadata(), OpenRouterProvider::metadata(), + SageMakerTgiProvider::metadata(), VeniceProvider::metadata(), SnowflakeProvider::metadata(), ] @@ -122,6 +124,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result> "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)), "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)), "google" => Ok(Arc::new(GoogleProvider::from_env(model)?)), + "sagemaker_tgi" => Ok(Arc::new(SageMakerTgiProvider::from_env(model)?)), "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)), "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)), "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)), diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 2f6f1f87bd2c..ccda29a431d4 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -18,6 +18,7 @@ pub mod oauth; pub mod ollama; pub mod openai; pub mod openrouter; +pub mod sagemaker_tgi; pub mod snowflake; pub mod toolshim; pub mod utils; diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs new file mode 100644 index 000000000000..19b95deb83f6 --- /dev/null +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -0,0 +1,365 @@ +use std::collections::HashMap; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use aws_config; +use aws_sdk_bedrockruntime::config::ProvideCredentials; +use aws_sdk_sagemakerruntime::Client as SageMakerClient; +use mcp_core::Tool; +use serde_json::{json, Value}; +use tokio::time::sleep; + +use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::errors::ProviderError; +use super::utils::emit_debug_trace; +use crate::message::{Message, MessageContent}; +use crate::model::ModelConfig; +use chrono::Utc; +use mcp_core::content::TextContent; +use mcp_core::role::Role; + +pub const SAGEMAKER_TGI_DOC_LINK: &str = + "https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html"; + +pub const SAGEMAKER_TGI_DEFAULT_MODEL: &str = "sagemaker-tgi-endpoint"; + +#[derive(Debug, serde::Serialize)] +pub struct SageMakerTgiProvider { + #[serde(skip)] + sagemaker_client: SageMakerClient, + endpoint_name: String, + model: ModelConfig, +} + +impl SageMakerTgiProvider { + pub fn from_env(model: ModelConfig) -> Result { + let config = crate::config::Config::global(); + + // Get SageMaker endpoint name (just the name, not full URL) + let endpoint_name: String = config.get_param("SAGEMAKER_ENDPOINT_NAME").map_err(|_| { + anyhow::anyhow!("SAGEMAKER_ENDPOINT_NAME is required for SageMaker TGI provider") + })?; + + // Attempt to load config and secrets to get AWS_ prefixed keys + let set_aws_env_vars = |res: Result, _>| { + if let Ok(map) = res { + map.into_iter() + .filter(|(key, _)| key.starts_with("AWS_")) + .filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string()))) + .for_each(|(key, s)| std::env::set_var(key, s)); + } + }; + + set_aws_env_vars(config.load_values()); + set_aws_env_vars(config.load_secrets()); + + let aws_config = futures::executor::block_on(aws_config::load_from_env()); + + // Validate credentials + futures::executor::block_on( + aws_config + .credentials_provider() + .unwrap() + .provide_credentials(), + )?; + + // Create client with longer timeout for model initialization + let timeout_config = aws_config::timeout::TimeoutConfig::builder() + .operation_timeout(Duration::from_secs(300)) // 5 minutes for cold starts + .build(); + + let config_with_timeout = aws_config + .into_builder() + .timeout_config(timeout_config) + .build(); + + let sagemaker_client = SageMakerClient::new(&config_with_timeout); + + Ok(Self { + sagemaker_client, + endpoint_name, + model, + }) + } + + fn create_tgi_request(&self, system: &str, messages: &[Message]) -> Result { + // Create a simplified prompt for TGI models using recent user and assistant messages. + // Uses a minimal system prompt and avoids HTML or tool-related formatting. + let mut prompt = String::new(); + + // Use a very simple system prompt if provided, but ensure it doesn't contain HTML instructions + if !system.is_empty() + && !system.contains("Available tools") + && system.len() < 200 + && !system.contains("HTML") + && !system.contains("markdown") + { + prompt.push_str(&format!("System: {}\n\n", system)); + } else { + // Use a minimal system prompt for TGI that explicitly avoids HTML + prompt.push_str("System: You are a helpful AI assistant. Provide responses in plain text only. Do not use HTML tags, markup, or formatting.\n\n"); + } + + // Only include the most recent user messages to avoid overwhelming the model + let recent_messages: Vec<_> = messages.iter().rev().take(3).collect(); + for message in recent_messages.iter().rev() { + match &message.role { + Role::User => { + prompt.push_str("User: "); + for content in &message.content { + if let MessageContent::Text(text) = content { + prompt.push_str(&text.text); + } + } + prompt.push_str("\n\n"); + } + Role::Assistant => { + prompt.push_str("Assistant: "); + for content in &message.content { + if let MessageContent::Text(text) = content { + // Skip responses that look like tool descriptions or contain HTML + if !text.text.contains("__") + && !text.text.contains("Available tools") + && !text.text.contains("<") + { + prompt.push_str(&text.text); + } + } + } + prompt.push_str("\n\n"); + } + } + } + + prompt.push_str("Assistant: "); + + // Skip tool descriptions entirely for TGI models to avoid confusion + // TGI models don't support tools natively and including tool descriptions + // causes them to mimic that format in their responses + + // Build TGI request with reasonable parameters + let request = json!({ + "inputs": prompt, + "parameters": { + "max_new_tokens": self.model.max_tokens.unwrap_or(150), + "temperature": self.model.temperature.unwrap_or(0.7), + "do_sample": true, + "return_full_text": false + } + }); + + Ok(request) + } + + async fn invoke_endpoint(&self, payload: Value) -> Result { + let body = serde_json::to_string(&payload).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to serialize request: {}", e)) + })?; + + let response = self + .sagemaker_client + .invoke_endpoint() + .endpoint_name(&self.endpoint_name) + .content_type("application/json") + .body(body.into_bytes().into()) + .send() + .await + .map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?; + + let response_body = response + .body + .as_ref() + .ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?; + let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to decode response: {}", e)) + })?; + + serde_json::from_str(response_text).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e)) + }) + } + + fn parse_tgi_response(&self, response: Value) -> Result { + // Handle standard TGI response: [{"generated_text": "..."}] + let response_array = response + .as_array() + .ok_or_else(|| ProviderError::RequestFailed("Expected array response".to_string()))?; + + if response_array.is_empty() { + return Err(ProviderError::RequestFailed( + "Empty response array".to_string(), + )); + } + + let first_result = &response_array[0]; + let generated_text = first_result + .get("generated_text") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + ProviderError::RequestFailed("No generated_text in response".to_string()) + })?; + + // Strip any HTML tags that might have been generated + let clean_text = self.strip_html_tags(generated_text); + + Ok(Message { + role: Role::Assistant, + created: Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: clean_text, + annotations: None, + })], + }) + } + + /// Strip HTML tags from text to ensure clean output + fn strip_html_tags(&self, text: &str) -> String { + // Simple regex-free approach to strip common HTML tags + let mut result = text.to_string(); + + // Remove common HTML tags like , , , , etc. + let tags_to_remove = [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "
", + "
", + "

", + "

", + "
", + "
", + "", + "", + ]; + + for tag in &tags_to_remove { + result = result.replace(tag, ""); + } + + // Remove any remaining HTML-like tags using a simple pattern + // This is a basic implementation - for production use, consider using a proper HTML parser + while let Some(start) = result.find('<') { + if let Some(end) = result[start..].find('>') { + result.replace_range(start..start + end + 1, ""); + } else { + break; + } + } + + result.trim().to_string() + } +} + +impl Default for SageMakerTgiProvider { + fn default() -> Self { + let model = ModelConfig::new(SageMakerTgiProvider::metadata().default_model); + SageMakerTgiProvider::from_env(model).expect("Failed to initialize SageMaker TGI provider") + } +} + +#[async_trait] +impl Provider for SageMakerTgiProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "sagemaker_tgi", + "Amazon SageMaker TGI", + "Run Text Generation Inference models through Amazon SageMaker endpoints. Requires AWS credentials and a SageMaker endpoint URL.", + SAGEMAKER_TGI_DEFAULT_MODEL, + vec![SAGEMAKER_TGI_DEFAULT_MODEL], + SAGEMAKER_TGI_DOC_LINK, + vec![ + ConfigKey::new("SAGEMAKER_ENDPOINT_NAME", false, false, None), + ConfigKey::new("AWS_REGION", true, false, Some("us-east-1")), + ConfigKey::new("AWS_PROFILE", true, false, Some("default")), + ], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument( + skip(self, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let model_name = &self.model.model_name; + + let request_payload = self.create_tgi_request(system, messages).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to create request: {}", e)) + })?; + + // Retry configuration + const MAX_RETRIES: u32 = 3; + const INITIAL_BACKOFF_MS: u64 = 1000; // 1 second + const MAX_BACKOFF_MS: u64 = 30000; // 30 seconds + + let mut attempts = 0; + let mut backoff_ms = INITIAL_BACKOFF_MS; + + loop { + attempts += 1; + + match self.invoke_endpoint(request_payload.clone()).await { + Ok(response) => { + let message = self.parse_tgi_response(response)?; + + // TGI doesn't provide usage statistics, so we estimate + let usage = Usage { + input_tokens: Some(0), // Would need to tokenize input to get accurate count + output_tokens: Some(0), // Would need to tokenize output to get accurate count + total_tokens: Some(0), + }; + + // Add debug trace + let debug_payload = serde_json::json!({ + "system": system, + "messages": messages, + "tools": tools + }); + emit_debug_trace( + &self.model, + &debug_payload, + &serde_json::to_value(&message).unwrap_or_default(), + &usage, + ); + + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + return Ok((message, provider_usage)); + } + Err(err) => { + if attempts > MAX_RETRIES { + return Err(err); + } + + // Log retry attempt + tracing::warn!( + "SageMaker TGI request failed (attempt {}/{}), retrying in {} ms: {:?}", + attempts, + MAX_RETRIES, + backoff_ms, + err + ); + + // Wait before retry + sleep(Duration::from_millis(backoff_ms)).await; + backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS); + } + } + } + } +} diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 4d41251fc440..842f53789963 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -490,6 +490,17 @@ async fn test_snowflake_provider() -> Result<()> { .await } +#[tokio::test] +async fn test_sagemaker_tgi_provider() -> Result<()> { + test_provider( + "SageMakerTgi", + &["SAGEMAKER_ENDPOINT_NAME"], + None, + goose::providers::sagemaker_tgi::SageMakerTgiProvider::default, + ) + .await +} + // Print the final test report #[ctor::dtor] fn print_test_report() {