Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/packages/azure-ai/agent_framework_azure_ai/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Content,
HostedCodeInterpreterTool,
HostedFileSearchTool,
HostedImageGenerationTool,
HostedMCPTool,
HostedWebSearchTool,
ToolProtocol,
Expand All @@ -29,6 +30,8 @@
CodeInterpreterTool,
CodeInterpreterToolAuto,
FunctionTool,
ImageGenTool,
ImageGenToolInputImageMask,
MCPTool,
ResponseTextFormatConfigurationJsonObject,
ResponseTextFormatConfigurationJsonSchema,
Expand Down Expand Up @@ -480,6 +483,31 @@ def to_azure_ai_tools(
timezone=location.get("timezone"),
)
azure_tools.append(ws_tool)
case HostedImageGenerationTool():
opts = tool.options or {}
addl = tool.additional_properties or {}
# Azure ImageGenTool requires the constant model "gpt-image-1"
ig_tool: ImageGenTool = ImageGenTool(
model=opts.get("model_id", "gpt-image-1"), # type: ignore
size=cast(
Literal["1024x1024", "1024x1536", "1536x1024", "auto"] | None, opts.get("image_size")
),
output_format=cast(Literal["png", "webp", "jpeg"] | None, opts.get("media_type")),
input_image_mask=(
ImageGenToolInputImageMask(
image_url=addl.get("input_image_mask", {}).get("image_url"),
file_id=addl.get("input_image_mask", {}).get("file_id"),
)
if isinstance(addl.get("input_image_mask"), dict)
else None
),
quality=cast(Literal["low", "medium", "high", "auto"] | None, addl.get("quality")),
background=cast(Literal["transparent", "opaque", "auto"] | None, addl.get("background")),
output_compression=cast(int | None, addl.get("output_compression")),
moderation=cast(Literal["auto", "low"] | None, addl.get("moderation")),
partial_images=opts.get("streaming_count"),
)
azure_tools.append(ig_tool)
case _:
logger.debug("Unsupported tool passed (type: %s)", type(tool))
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import base64
import tempfile
from pathlib import Path
from urllib import request as urllib_request

import aiofiles
from agent_framework import DataContent, HostedImageGenerationTool
from agent_framework import HostedImageGenerationTool
from agent_framework.azure import AzureAIProjectAgentProvider
from azure.identity.aio import AzureCliCredential

Expand Down Expand Up @@ -32,10 +35,14 @@ async def main() -> None:
tools=[
HostedImageGenerationTool(
options={
"model": "gpt-image-1-mini",
"model_id": "gpt-image-1",
"image_size": "1024x1024",
"media_type": "png",
},
additional_properties={
"quality": "low",
"size": "1024x1024",
}
"background": "opaque",
},
)
],
)
Expand All @@ -54,17 +61,36 @@ async def main() -> None:
# Save the image to a file
print("Downloading generated image...")
image_data = [
content
content.outputs
for content in result.messages[0].contents
if isinstance(content, DataContent) and content.media_type == "image/png"
if content.type == "image_generation_tool_result" and content.outputs is not None
]
if image_data and image_data[0]:
# Save to the same directory as this script
# Save to the OS temporary directory
filename = "microsoft.png"
current_dir = Path(__file__).parent.resolve()
file_path = current_dir / filename
file_path = Path(tempfile.gettempdir()) / filename
# outputs can be a list of Content items (data/uri) or a single item
out = image_data[0][0] if isinstance(image_data[0], list) else image_data[0]
data_bytes: bytes | None = None
uri = getattr(out, "uri", None)
if isinstance(uri, str):
if ";base64," in uri:
try:
b64 = uri.split(";base64,", 1)[1]
data_bytes = base64.b64decode(b64)
except Exception:
data_bytes = None
else:
try:
data_bytes = await asyncio.to_thread(lambda: urllib_request.urlopen(uri).read())
except Exception:
data_bytes = None

if data_bytes is None:
raise RuntimeError("Image output present but could not retrieve bytes.")

async with aiofiles.open(file_path, "wb") as f:
await f.write(image_data[0].get_data_bytes())
await f.write(data_bytes)

print(f"Image downloaded and saved to: {file_path}")
else:
Expand Down
Loading