Skip to content
Merged
30 changes: 28 additions & 2 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,11 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
else:
if not normalized_messages:
raise ValueError("At least one message is required when starting a new task (no continuation_token).")
previous_task_id = session.state.get("a2a_task_id") if session else None
a2a_message = self._prepare_message_for_a2a(
normalized_messages[-1],
context_id=session.service_session_id if session else None,
previous_task_id=previous_task_id,
)
a2a_stream = self.client.send_message(SendMessageRequest(message=a2a_message))

Expand Down Expand Up @@ -363,6 +365,7 @@ async def _map_a2a_stream(

all_updates: list[AgentResponseUpdate] = []
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
last_task_id: str | None = None
# In non-streaming mode, accumulate intermediate status content so it
# can be surfaced when the terminal event arrives (mirroring v0.3.x
# behavior where the full Task history was available at completion).
Expand All @@ -372,6 +375,8 @@ async def _map_a2a_stream(
if payload_type == "message":
# Process A2A Message
msg = item.message
if msg.task_id:
last_task_id = msg.task_id
Comment thread
giles17 marked this conversation as resolved.
Outdated
contents = self._parse_contents_from_a2a(msg.parts)
metadata = MessageToDict(msg.metadata) if msg.metadata else None
update = AgentResponseUpdate(
Expand All @@ -385,6 +390,7 @@ async def _map_a2a_stream(
yield update
elif payload_type == "task":
task = item.task
last_task_id = task.id
updates = self._updates_from_task(
task,
background=background,
Expand All @@ -406,6 +412,7 @@ async def _map_a2a_stream(
yield update
elif payload_type == "status_update":
status_event = item.status_update
last_task_id = status_event.task_id or last_task_id
Comment thread
giles17 marked this conversation as resolved.
Outdated
updates = self._updates_from_task_update_event(status_event)
is_terminal = status_event.status.state in TERMINAL_TASK_STATES
if emit_intermediate:
Expand All @@ -431,6 +438,7 @@ async def _map_a2a_stream(
pending_updates_by_task.setdefault(status_event.task_id, []).extend(updates)
elif payload_type == "artifact_update":
artifact_event = item.artifact_update
last_task_id = artifact_event.task_id or last_task_id
updates = self._updates_from_task_update_event(artifact_event)
# Always yield artifact updates — they carry actual response
# content (files, data). Track IDs so that a subsequent
Expand All @@ -449,6 +457,11 @@ async def _map_a2a_stream(
if all_updates:
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]

# Persist the last task_id on the session so follow-up messages can
# reference it via reference_task_ids (task refinements).
if session is not None and last_task_id:
session.state["a2a_task_id"] = last_task_id

await self._run_after_providers(session=session, context=session_context)

# ------------------------------------------------------------------
Expand Down Expand Up @@ -618,7 +631,9 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp
return AgentResponse.from_updates(updates)
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)

def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage:
def _prepare_message_for_a2a(
self, message: Message, *, context_id: str | None = None, previous_task_id: str | None = None
) -> A2AMessage:
"""Prepare a Message for the A2A protocol.

Transforms Agent Framework Message objects into A2A protocol Messages by:
Expand All @@ -627,13 +642,19 @@ def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None =
- Converting file references (URI/data/hosted_file) to FilePart objects
- Preserving metadata and additional properties from the original message
- Setting the role to 'user' as framework messages are treated as user input
- Linking follow-up messages to previous tasks via reference_task_ids

Args:
message: The framework Message to convert.

Keyword Args:
context_id: Optional fallback context identifier (e.g. derived from
``AgentSession.service_session_id``). When the *message* already
carries a ``context_id`` in its ``additional_properties`` that
value takes precedence; otherwise this fallback is used.
previous_task_id: Optional task ID from a previous interaction. When
provided, the message is linked as a follow-up (task refinement)
via ``reference_task_ids``.
"""
parts: list[A2APart] = []
if not message.contents:
Expand Down Expand Up @@ -693,14 +714,19 @@ def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None =

a2a_metadata = message.additional_properties.get("a2a_metadata")

return A2AMessage(
a2a_message = A2AMessage(
role=A2ARole.ROLE_USER,
parts=parts,
message_id=message.message_id or uuid.uuid4().hex,
context_id=message.additional_properties.get("context_id") or context_id,
Comment thread
giles17 marked this conversation as resolved.
Outdated
metadata=a2a_metadata or {},
)

if previous_task_id:
a2a_message.reference_task_ids.append(previous_task_id)

Comment thread
giles17 marked this conversation as resolved.
return a2a_message

def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
"""Parse A2A Parts into Agent Framework Content.

Expand Down
122 changes: 122 additions & 0 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,3 +1669,125 @@ async def test_non_streaming_artifact_update_surfaces_content(


# endregion


# region Reference Task IDs Tests


@mark.asyncio
async def test_first_message_has_no_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
"""Test that the first message sent has no reference_task_ids."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-first", [{"content": "Hello back"}])

session = AgentSession()
await agent.run("Hello", session=session)

assert mock_a2a_client.last_message is not None
assert list(mock_a2a_client.last_message.reference_task_ids) == []


@mark.asyncio
async def test_follow_up_message_includes_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
"""Test that a follow-up message references the previous task_id."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-abc-123", [{"content": "First reply"}])

session = AgentSession()
await agent.run("Hello", session=session)

# Verify task_id was persisted on session
assert session.state.get("a2a_task_id") == "task-abc-123"

# Send a follow-up message
mock_a2a_client.add_task_response("task-def-456", [{"content": "Second reply"}])
await agent.run("Follow up", session=session)

assert mock_a2a_client.last_message is not None
assert list(mock_a2a_client.last_message.reference_task_ids) == ["task-abc-123"]


@mark.asyncio
async def test_reference_task_ids_updated_after_each_interaction(mock_a2a_client: MockA2AClient) -> None:
"""Test that reference_task_ids always points to the most recent task."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

session = AgentSession()

# First interaction
mock_a2a_client.add_task_response("task-1", [{"content": "Reply 1"}])
await agent.run("Message 1", session=session)
assert session.state["a2a_task_id"] == "task-1"

# Second interaction
mock_a2a_client.add_task_response("task-2", [{"content": "Reply 2"}])
await agent.run("Message 2", session=session)
assert mock_a2a_client.last_message.reference_task_ids == ["task-1"]
assert session.state["a2a_task_id"] == "task-2"

# Third interaction references the second task
mock_a2a_client.add_task_response("task-3", [{"content": "Reply 3"}])
await agent.run("Message 3", session=session)
assert mock_a2a_client.last_message.reference_task_ids == ["task-2"]
assert session.state["a2a_task_id"] == "task-3"


@mark.asyncio
async def test_task_id_tracked_from_status_update_events(mock_a2a_client: MockA2AClient) -> None:
"""Test that task_id is tracked even when response only contains status update events."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

# Simulate a stream that only has status_update events (no full task payload)
status_event = TaskStatusUpdateEvent(
task_id="task-from-status",
context_id="ctx-1",
status=TaskStatus(
state=TaskState.TASK_STATE_COMPLETED,
message=A2AMessage(
message_id="msg-status",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Done")],
),
),
)
mock_a2a_client.responses.append(StreamResponse(status_update=status_event))

session = AgentSession()
await agent.run("Hello", session=session)

assert session.state.get("a2a_task_id") == "task-from-status"


@mark.asyncio
async def test_no_session_does_not_crash_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
"""Test that running without a session (no reference tracking) works fine."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-no-session", [{"content": "Reply"}])

# Should not raise — no session means no reference_task_ids
response = await agent.run("Hello")
assert response is not None
assert mock_a2a_client.last_message.reference_task_ids == []


@mark.asyncio
async def test_task_id_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None:
"""Test that task_id is tracked from message payloads that include a task_id."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

# Simulate a response that is a message with task_id set (no task/status_update events)
message_with_task = A2AMessage(
message_id="msg-with-task",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Response")],
task_id="task-from-message",
)
mock_a2a_client.responses.append(StreamResponse(message=message_with_task))

session = AgentSession()
await agent.run("Hello", session=session)

assert session.state.get("a2a_task_id") == "task-from-message"


# endregion
Loading