"""
Image generation tool for Strands Agents.

Generates images from text prompts with Stability AI models hosted on
Amazon Bedrock and saves the results to a local ``output/`` directory.

Key Features:
1. Image Generation:
   - Text-to-image conversion using Stable Diffusion models
   - Supported models:
       - stability.sd3-5-large-v1:0
       - stability.stable-image-core-v1:1
       - stability.stable-image-ultra-v1:1
   - Customizable generation parameters (seed, aspect_ratio, output_format,
     negative_prompt)

2. Output Management:
   - Automatic local saving with intelligent filename generation

Usage with Strands Agent:
    # Basic usage with default parameters
    agent.tool.generate_image(prompt="A steampunk robot playing chess")

    # Advanced usage with Stable Diffusion
    agent.tool.generate_image(
        prompt="A futuristic city with flying cars",
        model_id="stability.sd3-5-large-v1:0",
        aspect_ratio="5:4",
        output_format="jpeg",
        negative_prompt="bad lighting, harsh lighting, abstract, surreal",
    )
"""

import base64
import json
import os
import random
import re
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:  # project types are only needed for annotations
    from strands.types.tools import ToolResult, ToolUse

# Models this tool accepts; also used to build the hint in error messages so
# the list and the message can never drift apart.
STABLE_DIFFUSION_MODEL_ID = [
    "stability.sd3-5-large-v1:0",
    "stability.stable-image-core-v1:1",
    "stability.stable-image-ultra-v1:1",
]


TOOL_SPEC = {
    "name": "generate_image",
    "description": "Generates an image using Stable Diffusion models based on a given prompt",
    "inputSchema": {
        "json": {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "The text prompt describing the desired image",
                },
                "model_id": {
                    "type": "string",
                    # NOTE: implicit string concatenation (not "\" continuations)
                    # keeps indentation spaces out of the runtime description.
                    "description": "Model id for image model, stability.sd3-5-large-v1:0, "
                    "stability.stable-image-core-v1:1, or stability.stable-image-ultra-v1:1",
                },
                "region": {
                    "type": "string",
                    "description": "AWS region for the image generation model (default: us-west-2)",
                },
                "seed": {
                    "type": "integer",
                    "description": "Optional: Seed for random number generation (default: random)",
                },
                "aspect_ratio": {
                    "type": "string",
                    "description": "Optional: Controls the aspect ratio of the generated image for "
                    "Stable Diffusion models. Default 1:1. "
                    "Enum: 16:9, 1:1, 21:9, 2:3, 3:2, 4:5, 5:4, 9:16, 9:21",
                },
                "output_format": {
                    "type": "string",
                    "description": "Optional: Specifies the format of the output image for Stable "
                    "Diffusion models. Supported formats: JPEG, PNG.",
                },
                "negative_prompt": {
                    "type": "string",
                    "description": "Optional: Keywords of what you do not wish to see in the output "
                    "image. Default: bad lighting, harsh lighting. Max: 10,000 characters.",
                },
            },
            "required": ["prompt"],
        }
    },
}


def create_filename(prompt: str) -> str:
    """Generate a filesystem-safe filename stem from the first words of *prompt*."""
    words = re.findall(r"\w+", prompt.lower())[:5]
    filename = "_".join(words)
    filename = re.sub(r"[^\w\-_\.]", "_", filename)
    return filename[:100]  # Limit filename length


def generate_image(tool: "ToolUse", **kwargs: Any) -> "ToolResult":
    """
    Generate images from text prompts using Stable Diffusion models via Amazon Bedrock.

    How It Works:
    ------------
    1. Extracts and validates parameters from the tool input
    2. Builds the Bedrock request payload for the chosen Stability model
    3. Invokes the model through the bedrock-runtime client
    4. Decodes the base64 image from the response
    5. Saves it under ``output/`` with a name derived from the prompt

    Args:
        tool: ToolUse object containing the parameters for image generation.
            - prompt: The text prompt describing the desired image.
            - model_id: Optional model identifier (default: stability.stable-image-core-v1:1).
            - region, seed, aspect_ratio, output_format, negative_prompt: optional tuning.
        **kwargs: Additional keyword arguments (unused).

    Returns:
        ToolResult dict with "success" status, the saved path and image bytes,
        or an "error" status listing the supported model ids.
    """
    tool_use_id = tool["toolUseId"]
    tool_input = tool["input"]

    try:
        # Deferred import so the module can be imported (e.g. for TOOL_SPEC
        # introspection) without the AWS SDK installed; failure is reported
        # through the normal error ToolResult below.
        import boto3

        # Extract common and Stable Diffusion input parameters
        aspect_ratio = tool_input.get("aspect_ratio", "1:1")
        output_format = tool_input.get("output_format", "jpeg")
        prompt = tool_input.get("prompt", "A stylized picture of a cute old steampunk robot.")
        model_id = tool_input.get("model_id", "stability.stable-image-core-v1:1")
        region = tool_input.get("region", "us-west-2")
        seed = tool_input.get("seed", random.randint(0, 4294967295))
        negative_prompt = tool_input.get("negative_prompt", "bad lighting, harsh lighting")

        # Create a Bedrock Runtime client
        client = boto3.client("bedrock-runtime", region_name=region)

        # Create the request body (Stability "stable image" request schema)
        native_request = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "seed": seed,
            "output_format": output_format,
            "negative_prompt": negative_prompt,
        }
        request = json.dumps(native_request)

        # Invoke the model and decode the response body
        response = client.invoke_model(modelId=model_id, body=request)
        model_response = json.loads(response["body"].read().decode("utf-8"))

        # Extract the image data
        base64_image_data = model_response["images"][0]
        if not base64_image_data:
            raise Exception("No image data found in the response.")

        # Save the generated image to a local folder.
        # BUG FIX: the filename derived from the prompt was previously computed
        # but never used, and the extension was hard-coded to .png even for
        # JPEG output; both are corrected here.
        filename = create_filename(prompt)
        output_dir = "output"
        os.makedirs(output_dir, exist_ok=True)

        image_path = os.path.join(output_dir, f"{filename}.{output_format}")
        counter = 1
        while os.path.exists(image_path):  # avoid clobbering earlier results
            image_path = os.path.join(output_dir, f"{filename}_{counter}.{output_format}")
            counter += 1

        image_bytes = base64.b64decode(base64_image_data)
        with open(image_path, "wb") as file:
            file.write(image_bytes)

        return {
            "toolUseId": tool_use_id,
            "status": "success",
            "content": [
                {"text": f"The generated image has been saved locally to {image_path}. "},
                {
                    "image": {
                        "format": output_format,
                        "source": {"bytes": image_bytes},
                    }
                },
            ],
        }

    except Exception as e:
        # Build the supported-model hint from the canonical list above.
        supported = " \n ".join(
            f"{index}. {model}" for index, model in enumerate(STABLE_DIFFUSION_MODEL_ID, start=1)
        )
        return {
            "toolUseId": tool_use_id,
            "status": "error",
            "content": [
                {
                    "text": f"Error generating image: {str(e)} \n "
                    f"Try other supported models for this tool are: \n {supported}"
                }
            ],
        }
cute robot" assert request_body["seed"] == 123 - assert request_body["steps"] == 30 - assert request_body["cfg_scale"] == 10 - assert request_body["style_preset"] == "photographic" + assert request_body["aspect_ratio"] == "5:4" + assert request_body["output_format"] == "png" + assert request_body["negative_prompt"] == "blurry, low resolution, pixelated, grainy, unrealistic" # Verify directory creation mock_os_makedirs.assert_called_once() @@ -128,9 +128,9 @@ def test_generate_image_default_params(mock_boto3_client, mock_os_path_exists, m request_body = json.loads(kwargs["body"]) assert request_body["seed"] == 42 # From our mocked random.randint - assert request_body["steps"] == 30 - assert request_body["cfg_scale"] == 10 - assert request_body["style_preset"] == "photographic" + assert request_body["aspect_ratio"] == "1:1" + assert request_body["output_format"] == "jpeg" + assert request_body["negative_prompt"] == "bad lighting, harsh lighting" assert result["status"] == "success"