Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
01b421e
fix: change "client/test_client.py" to "client/test_client_factory.py…
sokoliva Oct 23, 2025
697438f
feat: Add client-side extension support
sokoliva Oct 28, 2025
17d30a4
Merge branch 'main' into Extension-support-for-Client
sokoliva Oct 28, 2025
860f2d5
refactor: remove redundant tests for send_message without extensions …
sokoliva Oct 29, 2025
511de38
refactor: reorder parameters in JsonRpcTransport and RestTransport co…
sokoliva Oct 29, 2025
6e80123
refactor: reorder parameters in JsonRpcTransport and RestTransport co…
sokoliva Oct 29, 2025
fd5986a
Merge branch 'Extension-support-for-Client' of https://github.com/sok…
sokoliva Oct 29, 2025
31a4581
Fix Parsing Bug in _update_extension_header method
sokoliva Oct 29, 2025
5fc530e
Fix Parsing Bug in _update_extension_header method
sokoliva Oct 29, 2025
3144f43
Merge branch 'Extension-support-for-Client' of https://github.com/sok…
sokoliva Oct 29, 2025
97eec52
refactor: streamline extension header handling in JsonRpcTransport an…
sokoliva Oct 29, 2025
caba0a2
refactor: rename client_extensions to extensions in JsonRpcTransport …
sokoliva Oct 30, 2025
28b1d53
feat: move common functions for managing HTTP extension headers to ut…
sokoliva Nov 3, 2025
270d6e7
Remove extensions from grpc methog get_card
sokoliva Nov 3, 2025
a9aa9ee
feat: add support for extensions in Client and BaseClient, update tra…
sokoliva Nov 3, 2025
948d3f3
fix: correct order of extension header updates in update_extension_he…
sokoliva Nov 3, 2025
4073c0b
refactor: streamline extension handling in BaseClient and GrpcTranspo…
sokoliva Nov 4, 2025
c5cea2c
Move transport tests from tests/client to tests/client/transport. Add…
sokoliva Nov 5, 2025
6e856d5
feat: enhance GrpcTransport to manage extensions in metadata and upda…
sokoliva Nov 6, 2025
edd7982
refactor: remove unused __merge_extensions function from utils.py
sokoliva Nov 6, 2025
48ea2ae
feat: update extension handling in transports and tests, migrate util…
sokoliva Nov 12, 2025
ffc0279
Merge remote-tracking branch 'origin/main' into Extension-support-for…
sokoliva Nov 12, 2025
5b47562
fix(client): clarify the purpose of the extensions parameter in Clien…
sokoliva Nov 12, 2025
a2eeb7b
feat: enhance extension handling across client and transport layers
sokoliva Nov 13, 2025
0746541
feat: add extensions parameter documentation in ClientFactory and upd…
sokoliva Nov 13, 2025
f5443d6
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 13, 2025
1337dcf
refactor: streamline extension handling in transport classes and upda…
sokoliva Nov 14, 2025
7f4ba58
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 14, 2025
a97c5b3
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 14, 2025
16ee453
add integration test for extensions. Add a test case to test_common.p…
sokoliva Nov 17, 2025
674e840
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 17, 2025
7fb55d0
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 17, 2025
80be4bf
change test case name in tests/extensions/test_common.py
sokoliva Nov 17, 2025
4a423ef
Change the order of update_extension_header and _apply_interceptors f…
sokoliva Nov 18, 2025
9a5b1d6
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 18, 2025
125406d
Change assertion in test_client_server_integration
sokoliva Nov 18, 2025
f581c27
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class ClientConfig:
)
"""Push notification callbacks to use for every request."""

extensions: list[str] = dataclasses.field(default_factory=list)
"""A list of extension URIs the client supports."""


UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None
# Alias for emitted events from client
Expand Down
2 changes: 2 additions & 0 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _register_defaults(
TransportProtocol.jsonrpc,
lambda card, url, config, interceptors: JsonRpcTransport(
config.httpx_client or httpx.AsyncClient(),
config.extensions or None,
card,
url,
interceptors,
Expand All @@ -87,6 +88,7 @@ def _register_defaults(
TransportProtocol.http_json,
lambda card, url, config, interceptors: RestTransport(
config.httpx_client or httpx.AsyncClient(),
config.extensions or None,
card,
url,
interceptors,
Expand Down
19 changes: 19 additions & 0 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.types import (
AgentCard,
CancelTaskRequest,
Expand Down Expand Up @@ -56,14 +57,15 @@
class JsonRpcTransport(ClientTransport):
"""A JSON-RPC transport for the A2A client."""

def __init__(
self,
httpx_client: httpx.AsyncClient,
client_extensions: list[str] | None = None,
agent_card: AgentCard | None = None,
url: str | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
):
"""Initializes the JsonRpcTransport."""

Check notice on line 68 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (41-49)
if url:
self.url = url
elif agent_card:
Expand All @@ -71,26 +73,41 @@
else:
raise ValueError('Must provide either agent_card or url')

self.httpx_client = httpx_client
self.client_extensions = client_extensions
self.agent_card = agent_card
self.interceptors = interceptors or []
self._needs_extended_card = (
agent_card.supports_authenticated_extended_card
if agent_card
else True
)

def _update_extension_header(
self, http_kwargs: dict[str, Any]
) -> dict[str, Any]:
if self.client_extensions:
headers = http_kwargs.get('headers', {})
existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '')
split = (
existing_extensions.split(', ') if existing_extensions else []
)
updated_extensions = list(set(self.client_extensions + split))
headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions)
http_kwargs['headers'] = headers
return http_kwargs

async def _apply_interceptors(
self,
method_name: str,

Check notice on line 102 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (58-84)
request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
final_http_kwargs = http_kwargs or {}
final_request_payload = request_payload

for interceptor in self.interceptors:

Check notice on line 110 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (83-90)
(
final_request_payload,
final_http_kwargs,
Expand Down Expand Up @@ -122,6 +139,7 @@
self._get_http_args(context),
context,
)
modified_kwargs = self._update_extension_header(modified_kwargs)
response_data = await self._send_request(payload, modified_kwargs)
response = SendMessageResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
Expand All @@ -147,6 +165,7 @@
context,
)

modified_kwargs = self._update_extension_header(modified_kwargs)
modified_kwargs.setdefault(
'timeout', self.httpx_client.timeout.as_dict().get('read', None)
)
Expand Down
18 changes: 18 additions & 0 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,40 @@
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.grpc import a2a_pb2
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
Message,
MessageSendParams,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskPushNotificationConfig,
TaskQueryParams,
TaskStatusUpdateEvent,
)
from a2a.utils import proto_utils
from a2a.utils.telemetry import SpanKind, trace_class


logger = logging.getLogger(__name__)


@trace_class(kind=SpanKind.CLIENT)
class RestTransport(ClientTransport):

Check notice on line 38 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/grpc.py (20-40)
"""A REST transport for the A2A client."""

def __init__(
self,
httpx_client: httpx.AsyncClient,
client_extensions: list[str] | None = None,
agent_card: AgentCard | None = None,
url: str | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
):
"""Initializes the RestTransport."""

Check notice on line 49 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (60-68)
if url:
self.url = url
elif agent_card:
Expand All @@ -53,24 +55,39 @@
raise ValueError('Must provide either agent_card or url')
if self.url.endswith('/'):
self.url = self.url[:-1]
self.httpx_client = httpx_client
self.client_extensions = client_extensions
self.agent_card = agent_card
self.interceptors = interceptors or []
self._needs_extended_card = (
agent_card.supports_authenticated_extended_card
if agent_card
else True
)

def _update_extension_header(
self, http_kwargs: dict[str, Any]
) -> dict[str, Any]:
if self.client_extensions:
headers = http_kwargs.get('headers', {})
existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '')
split = (
existing_extensions.split(', ') if existing_extensions else []
)
updated_extensions = list(set(self.client_extensions + split))
headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions)
http_kwargs['headers'] = headers
return http_kwargs

async def _apply_interceptors(
self,
request_payload: dict[str, Any],

Check notice on line 84 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (76-102)
http_kwargs: dict[str, Any] | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
final_http_kwargs = http_kwargs or {}
final_request_payload = request_payload
# TODO: Implement interceptors for other transports

Check notice on line 90 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (102-110)
return final_request_payload, final_http_kwargs

def _get_http_args(
Expand Down Expand Up @@ -98,6 +115,7 @@
self._get_http_args(context),
context,
)
modified_kwargs = self._update_extension_header(modified_kwargs)
return payload, modified_kwargs

async def send_message(
Expand Down
2 changes: 1 addition & 1 deletion tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

1. Run the tests
```bash
uv run pytest -v -s client/test_client.py
uv run pytest -v -s client/test_client_factory.py
```

In case of failures, you can cleanup the cache:
Expand Down
179 changes: 179 additions & 0 deletions tests/client/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
create_text_message_object,
)
from a2a.client.transports.jsonrpc import JsonRpcTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.types import (
AgentCapabilities,
AgentCard,
Expand Down Expand Up @@ -785,3 +786,181 @@ async def test_close(self, mock_httpx_client: AsyncMock):
)
await client.close()
mock_httpx_client.aclose.assert_called_once()


class TestJsonRpcTransportExtensions:
def test_update_extension_header_no_initial_headers(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
extensions = ['test_extension_1', 'test_extension_2']
client = JsonRpcTransport(
mock_httpx_client, extensions, mock_agent_card
)
http_kwargs = {}
result_kwargs = client._update_extension_header(http_kwargs)
actual_extensions = set(
result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ')
)
expected_extensions = {'test_extension_1', 'test_extension_2'}
assert actual_extensions == expected_extensions

def test_update_extension_header_with_existing_other_headers(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
extensions = ['test_extension_1']
client = JsonRpcTransport(
mock_httpx_client, extensions, mock_agent_card
)
http_kwargs = {'headers': {'X_Other': 'Test'}}
result_kwargs = client._update_extension_header(http_kwargs)
assert (
result_kwargs['headers'][HTTP_EXTENSION_HEADER]
== 'test_extension_1'
)
assert result_kwargs['headers']['X_Other'] == 'Test'

def test_update_extension_header_merge_with_existing_extensions(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
extensions = ['test_extension_1', 'test_extension_2']
client = JsonRpcTransport(
mock_httpx_client, extensions, mock_agent_card
)
http_kwargs = {
'headers': {
HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3'
}
}
result_kwargs = client._update_extension_header(http_kwargs)
actual_extensions_list = result_kwargs['headers'][
HTTP_EXTENSION_HEADER
].split(', ')
actual_extensions = set(actual_extensions_list)
expected_extensions = {
'test_extension_1',
'test_extension_2',
'test_extension_3',
}
assert len(actual_extensions_list) == 3
assert actual_extensions == expected_extensions

def test_update_extension_header_no_client_extensions(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
client = JsonRpcTransport(mock_httpx_client, None, mock_agent_card)
http_kwargs = {'headers': {'X_Other': 'Test'}}
result_kwargs = client._update_extension_header(http_kwargs)
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
assert result_kwargs['headers']['X_Other'] == 'Test'

def test_update_extension_header_empty_client_extensions(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
client = JsonRpcTransport(mock_httpx_client, [], mock_agent_card)
http_kwargs = {'headers': {'X_Other': 'Test'}}
result_kwargs = client._update_extension_header(http_kwargs)
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
assert result_kwargs['headers']['X_Other'] == 'Test'

@pytest.mark.asyncio
async def test_send_message_with_extensions(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
"""Test that send_message adds extension headers when client_extensions are provided."""
extensions = ['test_extension_1', 'test_extension_2']
client = JsonRpcTransport(
httpx_client=mock_httpx_client,
client_extensions=extensions,
agent_card=mock_agent_card,
)
params = MessageSendParams(
message=create_text_message_object(content='Hello')
)
success_response = create_text_message_object(
role=Role.agent, content='Hi there!'
)
rpc_response = SendMessageSuccessResponse(
id='123', jsonrpc='2.0', result=success_response
)
# Mock the response from httpx_client.post
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = rpc_response.model_dump(mode='json')
mock_httpx_client.post.return_value = mock_response

await client.send_message(request=params)

mock_httpx_client.post.assert_called_once()
_, mock_kwargs = mock_httpx_client.post.call_args
headers = mock_kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER in headers
actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', '))
expected_extensions = {'test_extension_1', 'test_extension_2'}
assert actual_extensions == expected_extensions

@pytest.mark.asyncio
async def test_send_message_no_extensions(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
"""Test that send_message does not add extension headers when client_extensions is None."""
client = JsonRpcTransport(
httpx_client=mock_httpx_client,
client_extensions=None,
agent_card=mock_agent_card,
)
params = MessageSendParams(
message=create_text_message_object(content='Hello')
)
success_response = create_text_message_object(
role=Role.agent, content='Hi there!'
)
rpc_response = SendMessageSuccessResponse(
id='123', jsonrpc='2.0', result=success_response
)
# Mock the response from httpx_client.post
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = rpc_response.model_dump(mode='json')
mock_httpx_client.post.return_value = mock_response

await client.send_message(request=params)

mock_httpx_client.post.assert_called_once()
_, mock_kwargs = mock_httpx_client.post.call_args
headers = mock_kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER not in headers

@pytest.mark.asyncio
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
async def test_send_message_streaming_with_extensions(
self,
mock_aconnect_sse: AsyncMock,
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
"""Test X-A2A-Extensions header in send_message_streaming."""
extensions = ['test_extension']
client = JsonRpcTransport(
httpx_client=mock_httpx_client,
client_extensions=extensions,
agent_card=mock_agent_card,
)
params = MessageSendParams(
message=create_text_message_object(content='Hello stream')
)

mock_event_source = AsyncMock(spec=EventSource)
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
mock_aconnect_sse.return_value.__aenter__.return_value = (
mock_event_source
)

async for _ in client.send_message_streaming(request=params):
pass

mock_aconnect_sse.assert_called_once()
_, kwargs = mock_aconnect_sse.call_args

headers = kwargs.get('headers', {})
assert HTTP_EXTENSION_HEADER in headers
assert headers[HTTP_EXTENSION_HEADER] == 'test_extension'
Loading
Loading