Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,226 +177,224 @@
async for sse in event_source.aiter_sse():
response = SendStreamingMessageResponse.model_validate(
json.loads(sse.data)
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(
self,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
try:
response = await self.httpx_client.post(
self.url, json=rpc_request_payload, **(http_kwargs or {})
)
response.raise_for_status()
return response.json()
except httpx.ReadTimeout as e:
raise A2AClientTimeoutError('Client Request timed out') from e
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
rpc_request = GetTaskRequest(params=request, id=str(uuid4()))
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/get',
rpc_request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = GetTaskResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
rpc_request = CancelTaskRequest(params=request, id=str(uuid4()))
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/cancel',
rpc_request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = CancelTaskResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
rpc_request = SetTaskPushNotificationConfigRequest(
params=request, id=str(uuid4())
)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/pushNotificationConfig/set',
rpc_request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = SetTaskPushNotificationConfigResponse.model_validate(
response_data
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
rpc_request = GetTaskPushNotificationConfigRequest(
params=request, id=str(uuid4())
)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/pushNotificationConfig/get',
rpc_request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_request(payload, modified_kwargs)
response = GetTaskPushNotificationConfigResponse.model_validate(
response_data
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
return response.root.result

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Reconnects to get task updates."""
rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4()))
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
'tasks/resubscribe',
rpc_request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
self.httpx_client,
'POST',
self.url,
json=payload,
**modified_kwargs,
) as event_source:
try:
async for sse in event_source.aiter_sse():
response = SendStreamingMessageResponse.model_validate_json(
sse.data
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card
if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(
http_kwargs=self._get_http_args(context)
)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
self._needs_extended_card = (

Check notice on line 389 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (359-397)
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))

Check notice on line 397 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (180-389)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
request.method,
request.model_dump(mode='json', exclude_none=True),
Expand Down
12 changes: 5 additions & 7 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,14 @@ async def get_card(
extensions: list[str] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card
if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(
http_kwargs=self._get_http_args(context)
)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
Expand All @@ -384,10 +386,6 @@ async def get_card(
if not self._needs_extended_card:
return card

modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
_, modified_kwargs = await self._apply_interceptors(
{},
modified_kwargs,
Expand Down
84 changes: 84 additions & 0 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,3 +875,87 @@ async def test_send_message_streaming_with_new_extensions(
assert (
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
)

@pytest.mark.asyncio
async def test_get_card_no_card_provided_with_extensions(
self, mock_httpx_client: AsyncMock
):
"""Test get_card with extensions set in Client when no card is initially provided.
Tests that the extensions are added to the HTTP GET request."""
extensions = [
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
]
client = JsonRpcTransport(
httpx_client=mock_httpx_client,
url=TestJsonRpcTransport.AGENT_URL,
extensions=extensions,
)
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = AGENT_CARD.model_dump(mode='json')
mock_httpx_client.get.return_value = mock_response

await client.get_card()

mock_httpx_client.get.assert_called_once()
_, mock_kwargs = mock_httpx_client.get.call_args

headers = mock_kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER in headers
header_value = headers[HTTP_EXTENSION_HEADER]
actual_extensions_list = [e.strip() for e in header_value.split(',')]
actual_extensions = set(actual_extensions_list)

expected_extensions = {
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
}
assert len(actual_extensions_list) == 2
assert actual_extensions == expected_extensions

@pytest.mark.asyncio
async def test_get_card_with_extended_card_support_with_extensions(
self, mock_httpx_client: AsyncMock
):
"""Test get_card with extensions passed to get_card call when extended card support is enabled.
Tests that the extensions are added to the RPC request."""
extensions = [
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
]
agent_card = AGENT_CARD.model_copy(
update={'supports_authenticated_extended_card': True}
)
client = JsonRpcTransport(
httpx_client=mock_httpx_client,
agent_card=agent_card,
extensions=extensions,
)

rpc_response = {
'id': '123',
'jsonrpc': '2.0',
'result': AGENT_CARD_EXTENDED.model_dump(mode='json'),
}
with patch.object(
client, '_send_request', new_callable=AsyncMock
) as mock_send_request:
mock_send_request.return_value = rpc_response
await client.get_card(extensions=extensions)

mock_send_request.assert_called_once()
_, mock_kwargs = mock_send_request.call_args[0]

headers = mock_kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER in headers
header_value = headers[HTTP_EXTENSION_HEADER]
actual_extensions_list = [e.strip() for e in header_value.split(',')]
actual_extensions = set(actual_extensions_list)

expected_extensions = {
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
}
assert len(actual_extensions_list) == 2
assert actual_extensions == expected_extensions
111 changes: 110 additions & 1 deletion tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from a2a.client import create_text_message_object
from a2a.client.transports.rest import RestTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.types import AgentCard, MessageSendParams, Role
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
MessageSendParams,
Role,
)


@pytest.fixture
Expand Down Expand Up @@ -119,3 +125,106 @@ async def test_send_message_streaming_with_new_extensions(
assert (
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
)

@pytest.mark.asyncio
async def test_get_card_no_card_provided_with_extensions(
self, mock_httpx_client: AsyncMock
):
"""Test get_card with extensions set in Client when no card is initially provided.
Tests that the extensions are added to the HTTP GET request."""
extensions = [
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
]
client = RestTransport(
httpx_client=mock_httpx_client,
url='http://agent.example.com/api',
extensions=extensions,
)

mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = {
'name': 'Test Agent',
'description': 'Test Agent Description',
'url': 'http://agent.example.com/api',
'version': '1.0.0',
'default_input_modes': ['text'],
'default_output_modes': ['text'],
'capabilities': AgentCapabilities().model_dump(),
'skills': [],
}
mock_httpx_client.get.return_value = mock_response

await client.get_card()

mock_httpx_client.get.assert_called_once()
_, mock_kwargs = mock_httpx_client.get.call_args

headers = mock_kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER in headers
header_value = headers[HTTP_EXTENSION_HEADER]
actual_extensions_list = [e.strip() for e in header_value.split(',')]
actual_extensions = set(actual_extensions_list)

expected_extensions = {
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
}
assert len(actual_extensions_list) == 2
assert actual_extensions == expected_extensions

@pytest.mark.asyncio
async def test_get_card_with_extended_card_support_with_extensions(
self, mock_httpx_client: AsyncMock
):
"""Test get_card with extensions passed to get_card call when extended card support is enabled.
Tests that the extensions are added to the GET request."""
extensions = [
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
]
agent_card = AgentCard(
name='Test Agent',
description='Test Agent Description',
url='http://agent.example.com/api',
version='1.0.0',
default_input_modes=['text'],
default_output_modes=['text'],
capabilities=AgentCapabilities(),
skills=[],
supports_authenticated_extended_card=True,
)
client = RestTransport(
httpx_client=mock_httpx_client,
agent_card=agent_card,
)

mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = agent_card.model_dump(mode='json')
mock_httpx_client.send.return_value = mock_response

with patch.object(
client, '_send_get_request', new_callable=AsyncMock
) as mock_send_get_request:
mock_send_get_request.return_value = agent_card.model_dump(
mode='json'
)
await client.get_card(extensions=extensions)

mock_send_get_request.assert_called_once()
_, _, mock_kwargs = mock_send_get_request.call_args[0]

headers = mock_kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER in headers
header_value = headers[HTTP_EXTENSION_HEADER]
actual_extensions_list = [e.strip() for e in header_value.split(',')]
actual_extensions = set(actual_extensions_list)

expected_extensions = {
'https://example.com/test-ext/v1',
'https://example.com/test-ext/v2',
}
assert len(actual_extensions_list) == 2
assert actual_extensions == expected_extensions
Loading