Skip to content

Commit 9a92bd2

Browse files
authored
feat: Add client-side extension support (#525)
feat(rest, jsonrpc): Add client-side extension support This commit introduces support for clients to declare the extensions they support. - Adds an `extensions` list to `ClientConfig`. - Updates `ClientFactory` to pass `client_extensions` to `JsonRpcTransport` and `RestTransport`. - Adds `_update_extension_header` method to both transports to update the `X-A2A-Extensions` header. - Modifies `send_message` and `send_message_streaming` in `JsonRpcTransport` to include the extension headers. - Modifies `_prepare_send_message` in `RestTransport` to include the extension headers. - Adds tests for the extension header logic in both JSON-RPC and REST transports, including a new test file `test_rest_client.py`. - Fixes #504 🦕
1 parent 89e9b7c commit 9a92bd2

File tree

14 files changed

+715
-57
lines changed

14 files changed

+715
-57
lines changed

src/a2a/client/base_client.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ async def send_message(
4949
*,
5050
context: ClientCallContext | None = None,
5151
request_metadata: dict[str, Any] | None = None,
52+
extensions: list[str] | None = None,
5253
) -> AsyncIterator[ClientEvent | Message]:
5354
"""Sends a message to the agent.
5455
@@ -60,6 +61,7 @@ async def send_message(
6061
request: The message to send to the agent.
6162
context: The client call context.
6263
request_metadata: Extensions Metadata attached to the request.
64+
extensions: List of extensions to be activated.
6365
6466
Yields:
6567
An async iterator of `ClientEvent` or a final `Message` response.
@@ -79,7 +81,7 @@ async def send_message(
7981

8082
if not self._config.streaming or not self._card.capabilities.streaming:
8183
response = await self._transport.send_message(
82-
params, context=context
84+
params, context=context, extensions=extensions
8385
)
8486
result = (
8587
(response, None) if isinstance(response, Task) else response
@@ -89,7 +91,9 @@ async def send_message(
8991
return
9092

9193
tracker = ClientTaskManager()
92-
stream = self._transport.send_message_streaming(params, context=context)
94+
stream = self._transport.send_message_streaming(
95+
params, context=context, extensions=extensions
96+
)
9397

9498
first_event = await anext(stream)
9599
# The response from a server may be either exactly one Message or a
@@ -126,74 +130,91 @@ async def get_task(
126130
request: TaskQueryParams,
127131
*,
128132
context: ClientCallContext | None = None,
133+
extensions: list[str] | None = None,
129134
) -> Task:
130135
"""Retrieves the current state and history of a specific task.
131136
132137
Args:
133138
request: The `TaskQueryParams` object specifying the task ID.
134139
context: The client call context.
140+
extensions: List of extensions to be activated.
135141
136142
Returns:
137143
A `Task` object representing the current state of the task.
138144
"""
139-
return await self._transport.get_task(request, context=context)
145+
return await self._transport.get_task(
146+
request, context=context, extensions=extensions
147+
)
140148

141149
async def cancel_task(
142150
self,
143151
request: TaskIdParams,
144152
*,
145153
context: ClientCallContext | None = None,
154+
extensions: list[str] | None = None,
146155
) -> Task:
147156
"""Requests the agent to cancel a specific task.
148157
149158
Args:
150159
request: The `TaskIdParams` object specifying the task ID.
151160
context: The client call context.
161+
extensions: List of extensions to be activated.
152162
153163
Returns:
154164
A `Task` object containing the updated task status.
155165
"""
156-
return await self._transport.cancel_task(request, context=context)
166+
return await self._transport.cancel_task(
167+
request, context=context, extensions=extensions
168+
)
157169

158170
async def set_task_callback(
159171
self,
160172
request: TaskPushNotificationConfig,
161173
*,
162174
context: ClientCallContext | None = None,
175+
extensions: list[str] | None = None,
163176
) -> TaskPushNotificationConfig:
164177
"""Sets or updates the push notification configuration for a specific task.
165178
166179
Args:
167180
request: The `TaskPushNotificationConfig` object with the new configuration.
168181
context: The client call context.
182+
extensions: List of extensions to be activated.
169183
170184
Returns:
171185
The created or updated `TaskPushNotificationConfig` object.
172186
"""
173-
return await self._transport.set_task_callback(request, context=context)
187+
return await self._transport.set_task_callback(
188+
request, context=context, extensions=extensions
189+
)
174190

175191
async def get_task_callback(
176192
self,
177193
request: GetTaskPushNotificationConfigParams,
178194
*,
179195
context: ClientCallContext | None = None,
196+
extensions: list[str] | None = None,
180197
) -> TaskPushNotificationConfig:
181198
"""Retrieves the push notification configuration for a specific task.
182199
183200
Args:
184201
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
185202
context: The client call context.
203+
extensions: List of extensions to be activated.
186204
187205
Returns:
188206
A `TaskPushNotificationConfig` object containing the configuration.
189207
"""
190-
return await self._transport.get_task_callback(request, context=context)
208+
return await self._transport.get_task_callback(
209+
request, context=context, extensions=extensions
210+
)
191211

192212
async def resubscribe(
193213
self,
194214
request: TaskIdParams,
195215
*,
196216
context: ClientCallContext | None = None,
217+
extensions: list[str] | None = None,
197218
) -> AsyncIterator[ClientEvent]:
198219
"""Resubscribes to a task's event stream.
199220
@@ -202,6 +223,7 @@ async def resubscribe(
202223
Args:
203224
request: Parameters to identify the task to resubscribe to.
204225
context: The client call context.
226+
extensions: List of extensions to be activated.
205227
206228
Yields:
207229
An async iterator of `ClientEvent` objects.
@@ -219,12 +241,15 @@ async def resubscribe(
219241
# we should never see Message updates, despite the typing of the service
220242
# definition indicating it may be possible.
221243
async for event in self._transport.resubscribe(
222-
request, context=context
244+
request, context=context, extensions=extensions
223245
):
224246
yield await self._process_response(tracker, event)
225247

226248
async def get_card(
227-
self, *, context: ClientCallContext | None = None
249+
self,
250+
*,
251+
context: ClientCallContext | None = None,
252+
extensions: list[str] | None = None,
228253
) -> AgentCard:
229254
"""Retrieves the agent's card.
230255
@@ -233,11 +258,14 @@ async def get_card(
233258
234259
Args:
235260
context: The client call context.
261+
extensions: List of extensions to be activated.
236262
237263
Returns:
238264
The `AgentCard` for the agent.
239265
"""
240-
card = await self._transport.get_card(context=context)
266+
card = await self._transport.get_card(
267+
context=context, extensions=extensions
268+
)
241269
self._card = card
242270
return card
243271

src/a2a/client/client.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class ClientConfig:
6767
)
6868
"""Push notification callbacks to use for every request."""
6969

70+
extensions: list[str] = dataclasses.field(default_factory=list)
71+
"""A list of extension URIs the client supports."""
72+
7073

7174
UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None
7275
# Alias for emitted events from client
@@ -111,6 +114,7 @@ async def send_message(
111114
*,
112115
context: ClientCallContext | None = None,
113116
request_metadata: dict[str, Any] | None = None,
117+
extensions: list[str] | None = None,
114118
) -> AsyncIterator[ClientEvent | Message]:
115119
"""Sends a message to the server.
116120
@@ -129,6 +133,7 @@ async def get_task(
129133
request: TaskQueryParams,
130134
*,
131135
context: ClientCallContext | None = None,
136+
extensions: list[str] | None = None,
132137
) -> Task:
133138
"""Retrieves the current state and history of a specific task."""
134139

@@ -138,6 +143,7 @@ async def cancel_task(
138143
request: TaskIdParams,
139144
*,
140145
context: ClientCallContext | None = None,
146+
extensions: list[str] | None = None,
141147
) -> Task:
142148
"""Requests the agent to cancel a specific task."""
143149

@@ -147,6 +153,7 @@ async def set_task_callback(
147153
request: TaskPushNotificationConfig,
148154
*,
149155
context: ClientCallContext | None = None,
156+
extensions: list[str] | None = None,
150157
) -> TaskPushNotificationConfig:
151158
"""Sets or updates the push notification configuration for a specific task."""
152159

@@ -156,6 +163,7 @@ async def get_task_callback(
156163
request: GetTaskPushNotificationConfigParams,
157164
*,
158165
context: ClientCallContext | None = None,
166+
extensions: list[str] | None = None,
159167
) -> TaskPushNotificationConfig:
160168
"""Retrieves the push notification configuration for a specific task."""
161169

@@ -165,14 +173,18 @@ async def resubscribe(
165173
request: TaskIdParams,
166174
*,
167175
context: ClientCallContext | None = None,
176+
extensions: list[str] | None = None,
168177
) -> AsyncIterator[ClientEvent]:
169178
"""Resubscribes to a task's event stream."""
170179
return
171180
yield
172181

173182
@abstractmethod
174183
async def get_card(
175-
self, *, context: ClientCallContext | None = None
184+
self,
185+
*,
186+
context: ClientCallContext | None = None,
187+
extensions: list[str] | None = None,
176188
) -> AgentCard:
177189
"""Retrieves the agent's card."""
178190

src/a2a/client/client_factory.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def _register_defaults(
8080
card,
8181
url,
8282
interceptors,
83+
config.extensions or None,
8384
),
8485
)
8586
if TransportProtocol.http_json in supported:
@@ -90,6 +91,7 @@ def _register_defaults(
9091
card,
9192
url,
9293
interceptors,
94+
config.extensions or None,
9395
),
9496
)
9597
if TransportProtocol.grpc in supported:
@@ -113,6 +115,7 @@ async def connect( # noqa: PLR0913
113115
relative_card_path: str | None = None,
114116
resolver_http_kwargs: dict[str, Any] | None = None,
115117
extra_transports: dict[str, TransportProducer] | None = None,
118+
extensions: list[str] | None = None,
116119
) -> Client:
117120
"""Convenience method for constructing a client.
118121
@@ -142,6 +145,7 @@ async def connect( # noqa: PLR0913
142145
A2AAgentCardResolver.get_agent_card as the http_kwargs parameter.
143146
extra_transports: Additional transport protocols to enable when
144147
constructing the client.
148+
extensions: List of extensions to be activated.
145149
146150
Returns:
147151
A `Client` object.
@@ -166,7 +170,7 @@ async def connect( # noqa: PLR0913
166170
factory = cls(client_config)
167171
for label, generator in (extra_transports or {}).items():
168172
factory.register(label, generator)
169-
return factory.create(card, consumers, interceptors)
173+
return factory.create(card, consumers, interceptors, extensions)
170174

171175
def register(self, label: str, generator: TransportProducer) -> None:
172176
"""Register a new transport producer for a given transport label."""
@@ -177,6 +181,7 @@ def create(
177181
card: AgentCard,
178182
consumers: list[Consumer] | None = None,
179183
interceptors: list[ClientCallInterceptor] | None = None,
184+
extensions: list[str] | None = None,
180185
) -> Client:
181186
"""Create a new `Client` for the provided `AgentCard`.
182187
@@ -186,6 +191,7 @@ def create(
186191
interceptors: A list of interceptors to use for each request. These
187192
are used for things like attaching credentials or http headers
188193
to all outbound requests.
194+
extensions: List of extensions to be activated.
189195
190196
Returns:
191197
A `Client` object.
@@ -226,12 +232,21 @@ def create(
226232
if consumers:
227233
all_consumers.extend(consumers)
228234

235+
all_extensions = self._config.extensions.copy()
236+
if extensions:
237+
all_extensions.extend(extensions)
238+
self._config.extensions = all_extensions
239+
229240
transport = self._registry[transport_protocol](
230241
card, transport_url, self._config, interceptors or []
231242
)
232243

233244
return BaseClient(
234-
card, self._config, transport, all_consumers, interceptors or []
245+
card,
246+
self._config,
247+
transport,
248+
all_consumers,
249+
interceptors or [],
235250
)
236251

237252

src/a2a/client/transports/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ async def send_message(
2525
request: MessageSendParams,
2626
*,
2727
context: ClientCallContext | None = None,
28+
extensions: list[str] | None = None,
2829
) -> Task | Message:
2930
"""Sends a non-streaming message request to the agent."""
3031

@@ -34,6 +35,7 @@ async def send_message_streaming(
3435
request: MessageSendParams,
3536
*,
3637
context: ClientCallContext | None = None,
38+
extensions: list[str] | None = None,
3739
) -> AsyncGenerator[
3840
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
3941
]:
@@ -47,6 +49,7 @@ async def get_task(
4749
request: TaskQueryParams,
4850
*,
4951
context: ClientCallContext | None = None,
52+
extensions: list[str] | None = None,
5053
) -> Task:
5154
"""Retrieves the current state and history of a specific task."""
5255

@@ -56,6 +59,7 @@ async def cancel_task(
5659
request: TaskIdParams,
5760
*,
5861
context: ClientCallContext | None = None,
62+
extensions: list[str] | None = None,
5963
) -> Task:
6064
"""Requests the agent to cancel a specific task."""
6165

@@ -65,6 +69,7 @@ async def set_task_callback(
6569
request: TaskPushNotificationConfig,
6670
*,
6771
context: ClientCallContext | None = None,
72+
extensions: list[str] | None = None,
6873
) -> TaskPushNotificationConfig:
6974
"""Sets or updates the push notification configuration for a specific task."""
7075

@@ -74,6 +79,7 @@ async def get_task_callback(
7479
request: GetTaskPushNotificationConfigParams,
7580
*,
7681
context: ClientCallContext | None = None,
82+
extensions: list[str] | None = None,
7783
) -> TaskPushNotificationConfig:
7884
"""Retrieves the push notification configuration for a specific task."""
7985

@@ -83,6 +89,7 @@ async def resubscribe(
8389
request: TaskIdParams,
8490
*,
8591
context: ClientCallContext | None = None,
92+
extensions: list[str] | None = None,
8693
) -> AsyncGenerator[
8794
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
8895
]:
@@ -95,6 +102,7 @@ async def get_card(
95102
self,
96103
*,
97104
context: ClientCallContext | None = None,
105+
extensions: list[str] | None = None,
98106
) -> AgentCard:
99107
"""Retrieves the AgentCard."""
100108

0 commit comments

Comments
 (0)