Skip to content

Commit aa4aaf5

Browse files
author
jer
committed
feat(a2a): support A2A FileParts and DataParts
1 parent c85464c commit aa4aaf5

File tree

2 files changed

+949
-24
lines changed

2 files changed

+949
-24
lines changed

src/strands/multiagent/a2a/executor.py

Lines changed: 191 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,28 @@
88
streamed requests to the A2AServer.
99
"""
1010

11+
import json
1112
import logging
12-
from typing import Any
13+
from typing import Any, Literal
1314

1415
from a2a.server.agent_execution import AgentExecutor, RequestContext
1516
from a2a.server.events import EventQueue
1617
from a2a.server.tasks import TaskUpdater
17-
from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
18+
from a2a.types import DataPart, FilePart, InternalError, Part, TaskState, TextPart, UnsupportedOperationError
1819
from a2a.utils import new_agent_text_message, new_task
1920
from a2a.utils.errors import ServerError
2021

2122
from ...agent.agent import Agent as SAAgent
2223
from ...agent.agent import AgentResult as SAAgentResult
24+
from ...types.content import ContentBlock
25+
from ...types.media import (
26+
DocumentContent,
27+
DocumentSource,
28+
ImageContent,
29+
ImageSource,
30+
VideoContent,
31+
VideoSource,
32+
)
2333

2434
logger = logging.getLogger(__name__)
2535

@@ -31,6 +41,26 @@ class StrandsA2AExecutor(AgentExecutor):
3141
and converts Strands Agent responses to A2A protocol events.
3242
"""
3343

44+
# File format mappings for different content types
45+
IMAGE_FORMAT_MAPPINGS = {"jpeg": "jpeg", "jpg": "jpeg", "png": "png", "gif": "gif", "webp": "webp"}
46+
47+
VIDEO_FORMAT_MAPPINGS = {
48+
"mp4": "mp4",
49+
"mpeg": "mpeg",
50+
"mpg": "mpg",
51+
"webm": "webm",
52+
"mov": "mov",
53+
"mkv": "mkv",
54+
"flv": "flv",
55+
"wmv": "wmv",
56+
"3gpp": "three_gp",
57+
}
58+
59+
DOCUMENT_FORMAT_MAPPINGS = {"pdf": "pdf", "csv": "csv", "html": "html", "plain": "txt", "markdown": "md"}
60+
61+
# Default formats for each file type when MIME type is unavailable
62+
DEFAULT_FORMATS = {"document": "txt", "image": "png", "video": "mp4", "unknown": "txt"}
63+
3464
def __init__(self, agent: SAAgent):
3565
"""Initialize a StrandsA2AExecutor.
3666
@@ -78,10 +108,15 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
78108
context: The A2A request context, containing the user's input and other metadata.
79109
updater: The task updater for managing task state and sending updates.
80110
"""
81-
logger.info("Executing request in streaming mode")
82-
user_input = context.get_user_input()
111+
# Convert A2A message parts to Strands ContentBlocks
112+
if context.message and hasattr(context.message, "parts"):
113+
content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts)
114+
else:
115+
# Fallback to original text extraction if no parts available
116+
user_input = context.get_user_input()
117+
content_blocks = [ContentBlock(text=user_input)]
83118
try:
84-
async for event in self.agent.stream_async(user_input):
119+
async for event in self.agent.stream_async(content_blocks):
85120
await self._handle_streaming_event(event, updater)
86121
except Exception:
87122
logger.exception("Error in streaming execution")
@@ -146,3 +181,154 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None
146181
"""
147182
logger.warning("Cancellation requested but not supported")
148183
raise ServerError(error=UnsupportedOperationError())
184+
185+
def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]:
186+
"""Classify file type based on MIME type.
187+
188+
Args:
189+
mime_type: The MIME type of the file
190+
191+
Returns:
192+
The classified file type
193+
"""
194+
if not mime_type:
195+
return "unknown"
196+
197+
mime_type = mime_type.lower()
198+
199+
if mime_type.startswith("image/"):
200+
return "image"
201+
elif mime_type.startswith("video/"):
202+
return "video"
203+
elif (
204+
mime_type.startswith("text/")
205+
or mime_type.startswith("application/")
206+
or mime_type in ["application/pdf", "application/json", "application/xml"]
207+
):
208+
return "document"
209+
else:
210+
return "unknown"
211+
212+
def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str:
213+
"""Extract file format from MIME type.
214+
215+
Args:
216+
mime_type: The MIME type of the file
217+
file_type: The classified file type (image, video, document, txt)
218+
219+
Returns:
220+
The file format string
221+
"""
222+
if not mime_type:
223+
return self.DEFAULT_FORMATS.get(file_type, "txt")
224+
225+
mime_type = mime_type.lower()
226+
227+
# Extract format from MIME type
228+
if "/" in mime_type:
229+
format_part = mime_type.split("/")[1]
230+
231+
# Handle common MIME type mappings with validation
232+
if file_type == "image":
233+
return self.IMAGE_FORMAT_MAPPINGS.get(format_part, "png")
234+
elif file_type == "video":
235+
return self.VIDEO_FORMAT_MAPPINGS.get(format_part, "mp4")
236+
else: # document
237+
return self.DOCUMENT_FORMAT_MAPPINGS.get(format_part, "txt")
238+
239+
# Fallback defaults
240+
return self.DEFAULT_FORMATS.get(file_type, "txt")
241+
242+
def _strip_file_extension(self, file_name: str) -> str:
243+
"""Strip the file extension from a file name.
244+
245+
Args:
246+
file_name: The original file name with extension
247+
248+
Returns:
249+
The file name without extension
250+
"""
251+
if "." in file_name:
252+
return file_name.rsplit(".", 1)[0]
253+
return file_name
254+
255+
def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[ContentBlock]:
256+
"""Convert A2A message parts to Strands ContentBlocks.
257+
258+
Args:
259+
parts: List of A2A Part objects
260+
261+
Returns:
262+
List of Strands ContentBlock objects
263+
"""
264+
content_blocks: list[ContentBlock] = []
265+
266+
for part in parts:
267+
try:
268+
part_root = part.root
269+
270+
if isinstance(part_root, TextPart):
271+
# Handle TextPart
272+
content_blocks.append(ContentBlock(text=part_root.text))
273+
274+
elif isinstance(part_root, FilePart):
275+
# Handle FilePart
276+
file_obj = part_root.file
277+
mime_type = getattr(file_obj, "mime_type", None)
278+
raw_file_name = getattr(file_obj, "name", "FileNameNotProvided")
279+
file_name = self._strip_file_extension(raw_file_name)
280+
file_type = self._get_file_type_from_mime_type(mime_type)
281+
file_format = self._get_file_format_from_mime_type(mime_type, file_type)
282+
283+
# Handle FileWithBytes vs FileWithUri
284+
bytes_data = getattr(file_obj, "bytes", None)
285+
uri_data = getattr(file_obj, "uri", None)
286+
287+
if bytes_data:
288+
if file_type == "image":
289+
content_blocks.append(
290+
ContentBlock(
291+
image=ImageContent(
292+
format=file_format, # type: ignore
293+
source=ImageSource(bytes=bytes_data),
294+
)
295+
)
296+
)
297+
elif file_type == "video":
298+
content_blocks.append(
299+
ContentBlock(
300+
video=VideoContent(
301+
format=file_format, # type: ignore
302+
source=VideoSource(bytes=bytes_data),
303+
)
304+
)
305+
)
306+
else: # document or unknown
307+
content_blocks.append(
308+
ContentBlock(
309+
document=DocumentContent(
310+
format=file_format, # type: ignore
311+
name=file_name,
312+
source=DocumentSource(bytes=bytes_data),
313+
)
314+
)
315+
)
316+
# Handle FileWithUri
317+
elif uri_data:
318+
# For URI files, create a text representation since Strands ContentBlocks expect bytes
319+
content_blocks.append(
320+
ContentBlock(
321+
text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data)
322+
)
323+
)
324+
elif isinstance(part_root, DataPart):
325+
# Handle DataPart - convert structured data to JSON text
326+
try:
327+
data_text = json.dumps(part_root.data, indent=2)
328+
content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text))
329+
except Exception:
330+
logger.exception("Failed to serialize data part")
331+
except Exception:
332+
logger.exception("Error processing part")
333+
334+
return content_blocks

0 commit comments

Comments
 (0)