diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index 23ebe1449c9e..24a75a745f9e 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -3,7 +3,7 @@ use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use dotenv::dotenv; use goose::{ message::Message, - providers::{databricks::DatabricksProvider, openai::OpenAiProvider}, + providers::{bedrock::BedrockProvider, databricks::DatabricksProvider, openai::OpenAiProvider}, }; use mcp_core::{ content::Content, @@ -21,6 +21,7 @@ async fn main() -> Result<()> { let providers: Vec> = vec![ Box::new(DatabricksProvider::default()), Box::new(OpenAiProvider::default()), + Box::new(BedrockProvider::default()), ]; for provider in providers { diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 3c0ea40c2a50..29b3491585de 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -4,12 +4,14 @@ use std::path::Path; use anyhow::{anyhow, bail, Result}; use aws_sdk_bedrockruntime::types as bedrock; use aws_smithy_types::{Document, Number}; +use base64::Engine; use chrono::Utc; use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; use super::super::base::Usage; use crate::message::{Message, MessageContent}; +use mcp_core::content::ImageContent; pub fn to_bedrock_message(message: &Message) -> Result { bedrock::Message::builder() @@ -31,9 +33,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result { bedrock::ContentBlock::Text("".to_string()) } - MessageContent::Image(_) => { - bail!("Image content is not supported by Bedrock provider yet") - } + MessageContent::Image(image) => bedrock::ContentBlock::Image(to_bedrock_image(image)?), MessageContent::Thinking(_) => { // Thinking blocks are not supported in Bedrock - skip bedrock::ContentBlock::Text("".to_string()) @@ -108,13 +108,17 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result Result { Ok(match content { Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), - Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), + Content::Image(image) => bedrock::ToolResultContentBlock::Image(to_bedrock_image(image)?), Content::Resource(resource) => match &resource.resource { ResourceContents::TextResourceContents { text, .. } => { match to_bedrock_document(tool_use_id, &resource.resource)? { @@ -136,6 +140,33 @@ pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole { } } +pub fn to_bedrock_image(image: &ImageContent) -> Result { + // Extract format from MIME type + let format = match image.mime_type.as_str() { + "image/png" => bedrock::ImageFormat::Png, + "image/jpeg" | "image/jpg" => bedrock::ImageFormat::Jpeg, + "image/gif" => bedrock::ImageFormat::Gif, + "image/webp" => bedrock::ImageFormat::Webp, + _ => bail!( + "Unsupported image format: {}. Bedrock supports png, jpeg, gif, webp", + image.mime_type + ), + }; + + // Create image source with base64 data + let source = bedrock::ImageSource::Bytes(aws_smithy_types::Blob::new( + base64::prelude::BASE64_STANDARD + .decode(&image.data) + .map_err(|e| anyhow!("Failed to decode base64 image data: {}", e))?, + )); + + // Build the image block + Ok(bedrock::ImageBlock::builder() + .format(format) + .source(source) + .build()?) +} + pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result { Ok(bedrock::ToolConfiguration::builder() .set_tools(Some( @@ -315,3 +346,100 @@ pub fn from_bedrock_json(document: &Document) -> Result { ), }) } + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + use mcp_core::content::ImageContent; + + // Base64 encoded 1x1 PNG image for testing + const TEST_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="; + + #[test] + fn test_to_bedrock_image_supported_formats() -> Result<()> { + let supported_formats = [ + "image/png", + "image/jpeg", + "image/jpg", + "image/gif", + "image/webp", + ]; + + for mime_type in supported_formats { + let image = ImageContent { + data: TEST_IMAGE_BASE64.to_string(), + mime_type: mime_type.to_string(), + annotations: None, + }; + + let result = to_bedrock_image(&image); + assert!(result.is_ok(), "Failed to convert {} format", mime_type); + } + + Ok(()) + } + + #[test] + fn test_to_bedrock_image_unsupported_format() { + let image = ImageContent { + data: TEST_IMAGE_BASE64.to_string(), + mime_type: "image/bmp".to_string(), + annotations: None, + }; + + let result = to_bedrock_image(&image); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Unsupported image format: image/bmp")); + assert!(error_msg.contains("Bedrock supports png, jpeg, gif, webp")); + } + + #[test] + fn test_to_bedrock_image_invalid_base64() { + let image = ImageContent { + data: "invalid_base64_data!!!".to_string(), + mime_type: "image/png".to_string(), + annotations: None, + }; + + let result = to_bedrock_image(&image); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Failed to decode base64 image data")); + } + + #[test] + fn test_to_bedrock_message_content_image() -> Result<()> { + let image = ImageContent { + data: TEST_IMAGE_BASE64.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + }; + + let message_content = MessageContent::Image(image); + let result = to_bedrock_message_content(&message_content)?; + + // Verify we get an Image content block + assert!(matches!(result, bedrock::ContentBlock::Image(_))); + + Ok(()) + } + + #[test] + fn test_to_bedrock_tool_result_content_block_image() -> Result<()> { + let image = ImageContent { + data: TEST_IMAGE_BASE64.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + }; + + let content = Content::Image(image); + let result = to_bedrock_tool_result_content_block("test_id", &content)?; + + // Verify the wrapper correctly converts Content::Image to ToolResultContentBlock::Image + assert!(matches!(result, bedrock::ToolResultContentBlock::Image(_))); + + Ok(()) + } +} diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 9cd229ce9a3a..c4884b7c437b 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -254,11 +254,102 @@ impl ProviderTester { Ok(()) } + async fn test_image_content_support(&self) -> Result<()> { + use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; + use mcp_core::content::ImageContent; + use std::fs; + + // Try to read the test image + let image_path = "crates/goose/examples/test_assets/test_image.png"; + let image_data = match fs::read(image_path) { + Ok(data) => data, + Err(_) => { + println!( + "Test image not found at {}, skipping image test", + image_path + ); + return Ok(()); + } + }; + + let base64_image = BASE64.encode(image_data); + let image_content = ImageContent { + data: base64_image, + mime_type: "image/png".to_string(), + annotations: None, + }; + + // Test 1: Direct image message + let message_with_image = + Message::user().with_image(image_content.data.clone(), image_content.mime_type.clone()); + + let result = self + .provider + .complete( + "You are a helpful assistant. Describe what you see in the image briefly.", + &[message_with_image], + &[], + ) + .await; + + println!("=== {}::image_content_support ===", self.name); + let (response, _) = result?; + println!("Image response: {:?}", response); + // Verify we got a text response + assert!( + response + .content + .iter() + .any(|content| matches!(content, MessageContent::Text(_))), + "Expected text response for image" + ); + println!("==================="); + + // Test 2: Tool response with image (this should be handled gracefully) + let screenshot_tool = Tool::new( + "get_screenshot", + "Get a screenshot of the current screen", + serde_json::json!({ + "type": "object", + "properties": {} + }), + None, + ); + + let user_message = Message::user().with_text("Take a screenshot please"); + let tool_request = Message::assistant().with_tool_request( + "test_id", + Ok(mcp_core::tool::ToolCall::new( + "get_screenshot", + serde_json::json!({}), + )), + ); + let tool_response = + Message::user().with_tool_response("test_id", Ok(vec![Content::Image(image_content)])); + + let result2 = self + .provider + .complete( + "You are a helpful assistant.", + &[user_message, tool_request, tool_response], + &[screenshot_tool], + ) + .await; + + println!("=== {}::tool_image_response ===", self.name); + let (response, _) = result2?; + println!("Tool image response: {:?}", response); + println!("==================="); + + Ok(()) + } + /// Run all provider tests async fn run_test_suite(&self) -> Result<()> { self.test_basic_response().await?; self.test_tool_usage().await?; self.test_context_length_exceeded_error().await?; + self.test_image_content_support().await?; Ok(()) } }