Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/goose/examples/image_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use dotenv::dotenv;
use goose::{
message::Message,
providers::{databricks::DatabricksProvider, openai::OpenAiProvider},
providers::{bedrock::BedrockProvider, databricks::DatabricksProvider, openai::OpenAiProvider},
};
use mcp_core::{
content::Content,
Expand All @@ -21,6 +21,7 @@ async fn main() -> Result<()> {
let providers: Vec<Box<dyn goose::providers::base::Provider + Send + Sync>> = vec![
Box::new(DatabricksProvider::default()),
Box::new(OpenAiProvider::default()),
Box::new(BedrockProvider::default()),
];

for provider in providers {
Expand Down
136 changes: 132 additions & 4 deletions crates/goose/src/providers/formats/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use std::path::Path;
use anyhow::{anyhow, bail, Result};
use aws_sdk_bedrockruntime::types as bedrock;
use aws_smithy_types::{Document, Number};
use base64::Engine;
use chrono::Utc;
use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult};
use serde_json::Value;

use super::super::base::Usage;
use crate::message::{Message, MessageContent};
use mcp_core::content::ImageContent;

pub fn to_bedrock_message(message: &Message) -> Result<bedrock::Message> {
bedrock::Message::builder()
Expand All @@ -31,9 +33,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
bedrock::ContentBlock::Text("".to_string())
}
MessageContent::Image(_) => {
bail!("Image content is not supported by Bedrock provider yet")
}
MessageContent::Image(image) => bedrock::ContentBlock::Image(to_bedrock_image(image)?),
MessageContent::Thinking(_) => {
// Thinking blocks are not supported in Bedrock - skip
bedrock::ContentBlock::Text("".to_string())
Expand Down Expand Up @@ -108,13 +108,17 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
})
}

/// Convert MCP Content to Bedrock ToolResultContentBlock
///
/// Supports text, images, and document resources. Images are supported
/// by Bedrock for Anthropic Claude 3 models.
pub fn to_bedrock_tool_result_content_block(
tool_use_id: &str,
content: &Content,
) -> Result<bedrock::ToolResultContentBlock> {
Ok(match content {
Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()),
Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"),
Content::Image(image) => bedrock::ToolResultContentBlock::Image(to_bedrock_image(image)?),
Content::Resource(resource) => match &resource.resource {
ResourceContents::TextResourceContents { text, .. } => {
match to_bedrock_document(tool_use_id, &resource.resource)? {
Expand All @@ -136,6 +140,33 @@ pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole {
}
}

pub fn to_bedrock_image(image: &ImageContent) -> Result<bedrock::ImageBlock> {
// Extract format from MIME type
let format = match image.mime_type.as_str() {
"image/png" => bedrock::ImageFormat::Png,
"image/jpeg" | "image/jpg" => bedrock::ImageFormat::Jpeg,
"image/gif" => bedrock::ImageFormat::Gif,
"image/webp" => bedrock::ImageFormat::Webp,
_ => bail!(
"Unsupported image format: {}. Bedrock supports png, jpeg, gif, webp",
image.mime_type
),
};

// Create image source with base64 data
let source = bedrock::ImageSource::Bytes(aws_smithy_types::Blob::new(
base64::prelude::BASE64_STANDARD
.decode(&image.data)
.map_err(|e| anyhow!("Failed to decode base64 image data: {}", e))?,
));

// Build the image block
Ok(bedrock::ImageBlock::builder()
.format(format)
.source(source)
.build()?)
}

pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result<bedrock::ToolConfiguration> {
Ok(bedrock::ToolConfiguration::builder()
.set_tools(Some(
Expand Down Expand Up @@ -315,3 +346,100 @@ pub fn from_bedrock_json(document: &Document) -> Result<Value> {
),
})
}

#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use mcp_core::content::ImageContent;

// Base64 encoded 1x1 PNG image for testing
const TEST_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";

#[test]
fn test_to_bedrock_image_supported_formats() -> Result<()> {
let supported_formats = [
"image/png",
"image/jpeg",
"image/jpg",
"image/gif",
"image/webp",
];

for mime_type in supported_formats {
let image = ImageContent {
data: TEST_IMAGE_BASE64.to_string(),
mime_type: mime_type.to_string(),
annotations: None,
};

let result = to_bedrock_image(&image);
assert!(result.is_ok(), "Failed to convert {} format", mime_type);
}

Ok(())
}

#[test]
fn test_to_bedrock_image_unsupported_format() {
let image = ImageContent {
data: TEST_IMAGE_BASE64.to_string(),
mime_type: "image/bmp".to_string(),
annotations: None,
};

let result = to_bedrock_image(&image);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("Unsupported image format: image/bmp"));
assert!(error_msg.contains("Bedrock supports png, jpeg, gif, webp"));
}

#[test]
fn test_to_bedrock_image_invalid_base64() {
let image = ImageContent {
data: "invalid_base64_data!!!".to_string(),
mime_type: "image/png".to_string(),
annotations: None,
};

let result = to_bedrock_image(&image);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("Failed to decode base64 image data"));
}

#[test]
fn test_to_bedrock_message_content_image() -> Result<()> {
let image = ImageContent {
data: TEST_IMAGE_BASE64.to_string(),
mime_type: "image/png".to_string(),
annotations: None,
};

let message_content = MessageContent::Image(image);
let result = to_bedrock_message_content(&message_content)?;

// Verify we get an Image content block
assert!(matches!(result, bedrock::ContentBlock::Image(_)));

Ok(())
}

#[test]
fn test_to_bedrock_tool_result_content_block_image() -> Result<()> {
let image = ImageContent {
data: TEST_IMAGE_BASE64.to_string(),
mime_type: "image/png".to_string(),
annotations: None,
};

let content = Content::Image(image);
let result = to_bedrock_tool_result_content_block("test_id", &content)?;

// Verify the wrapper correctly converts Content::Image to ToolResultContentBlock::Image
assert!(matches!(result, bedrock::ToolResultContentBlock::Image(_)));

Ok(())
}
}
91 changes: 91 additions & 0 deletions crates/goose/tests/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,102 @@ impl ProviderTester {
Ok(())
}

async fn test_image_content_support(&self) -> Result<()> {
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use mcp_core::content::ImageContent;
use std::fs;

// Try to read the test image
let image_path = "crates/goose/examples/test_assets/test_image.png";
let image_data = match fs::read(image_path) {
Ok(data) => data,
Err(_) => {
println!(
"Test image not found at {}, skipping image test",
image_path
);
return Ok(());
}
};

let base64_image = BASE64.encode(image_data);
let image_content = ImageContent {
data: base64_image,
mime_type: "image/png".to_string(),
annotations: None,
};

// Test 1: Direct image message
let message_with_image =
Message::user().with_image(image_content.data.clone(), image_content.mime_type.clone());

let result = self
.provider
.complete(
"You are a helpful assistant. Describe what you see in the image briefly.",
&[message_with_image],
&[],
)
.await;

println!("=== {}::image_content_support ===", self.name);
let (response, _) = result?;
println!("Image response: {:?}", response);
// Verify we got a text response
assert!(
response
.content
.iter()
.any(|content| matches!(content, MessageContent::Text(_))),
"Expected text response for image"
);
println!("===================");

// Test 2: Tool response with image (this should be handled gracefully)
let screenshot_tool = Tool::new(
"get_screenshot",
"Get a screenshot of the current screen",
serde_json::json!({
"type": "object",
"properties": {}
}),
None,
);

let user_message = Message::user().with_text("Take a screenshot please");
let tool_request = Message::assistant().with_tool_request(
"test_id",
Ok(mcp_core::tool::ToolCall::new(
"get_screenshot",
serde_json::json!({}),
)),
);
let tool_response =
Message::user().with_tool_response("test_id", Ok(vec![Content::Image(image_content)]));

let result2 = self
.provider
.complete(
"You are a helpful assistant.",
&[user_message, tool_request, tool_response],
&[screenshot_tool],
)
.await;

println!("=== {}::tool_image_response ===", self.name);
let (response, _) = result2?;
println!("Tool image response: {:?}", response);
println!("===================");

Ok(())
}

/// Run all provider tests
async fn run_test_suite(&self) -> Result<()> {
self.test_basic_response().await?;
self.test_tool_usage().await?;
self.test_context_length_exceeded_error().await?;
self.test_image_content_support().await?;
Ok(())
}
}
Expand Down