Skip to content

Commit b6419a0

Browse files
committed
Allow configuring headers and security checks.
1 parent 7932c20 commit b6419a0

18 files changed

+343
-24
lines changed

packages/a2a_dart/lib/src/client/http_transport.dart

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,32 @@ class HttpTransport implements Transport {
2525
/// The logger to use for logging.
2626
final Logger? log;
2727

28+
/// Default headers to include in all requests.
29+
final Map<String, String> defaultHeaders;
30+
2831
/// Creates an [HttpTransport].
2932
///
3033
/// The [url] is the base URL of the A2A server.
3134
/// The [client] is an optional HTTP client to use for requests. If not
3235
/// provided, a new one will be created.
3336
/// The [log] is an optional logger.
34-
HttpTransport({required this.url, http.Client? client, this.log})
35-
: client = client ?? http.Client();
37+
/// The [defaultHeaders] are optional headers to include in all requests.
38+
HttpTransport({
39+
required this.url,
40+
http.Client? client,
41+
this.log,
42+
this.defaultHeaders = const {},
43+
}) : client = client ?? http.Client();
3644

3745
@override
3846
Future<Map<String, Object?>> get(
3947
String path, {
4048
Map<String, String> headers = const {},
4149
}) async {
4250
final uri = Uri.parse('$url$path');
43-
log?.fine('Sending GET request to $uri');
44-
final response = await client.get(uri, headers: headers);
51+
final mergedHeaders = {...defaultHeaders, ...headers};
52+
log?.fine('Sending GET request to $uri with headers: $mergedHeaders');
53+
final response = await client.get(uri, headers: mergedHeaders);
4554
log?.fine('Received response from GET $uri: ${response.body}');
4655
if (response.statusCode != 200) {
4756
throw A2AException.http(
@@ -59,9 +68,13 @@ class HttpTransport implements Transport {
5968
}) async {
6069
final uri = Uri.parse('$url$path');
6170
log?.fine('Sending POST request to $uri with body: $request');
71+
final mergedHeaders = {
72+
'Content-Type': 'application/json',
73+
...defaultHeaders,
74+
};
6275
final response = await client.post(
6376
uri,
64-
headers: {'Content-Type': 'application/json'},
77+
headers: mergedHeaders,
6578
body: jsonEncode(request),
6679
);
6780
log?.fine('Received response from POST $uri: ${response.body}');

packages/a2a_dart/lib/src/client/sse_transport.dart

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,25 @@ class SseTransport extends HttpTransport {
2727
/// The [url] is the base URL of the A2A server. An optional [client] can be
2828
/// provided for testing or to customize the HTTP client. The [log] is an
2929
/// optional logger.
30-
SseTransport({required super.url, super.client, super.log});
30+
SseTransport({
31+
required super.url,
32+
super.client,
33+
super.log,
34+
super.defaultHeaders,
35+
});
3136

3237
@override
3338
Stream<Map<String, Object?>> sendStream(Map<String, Object?> request) async* {
3439
final uri = Uri.parse('$url/rpc');
3540
final body = jsonEncode(request);
3641
log?.fine('Sending SSE request to $uri with body: $body');
42+
final mergedHeaders = {
43+
'Content-Type': 'application/json',
44+
'Accept': 'text/event-stream',
45+
...defaultHeaders,
46+
};
3747
final httpRequest = http.Request('POST', uri)
38-
..headers['Content-Type'] = 'application/json'
39-
..headers['Accept'] = 'text/event-stream'
48+
..headers.addAll(mergedHeaders)
4049
..body = body;
4150

4251
try {

packages/a2a_dart/lib/src/server/create_task_handler.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class CreateTaskHandler implements RequestHandler {
2626
@override
2727
String get method => 'create_task';
2828

29+
@override
30+
List<Map<String, List<String>>>? get securityRequirements => null;
31+
2932
@override
3033
FutureOr<HandlerResult> handle(Map<String, Object?> params) async {
3134
if (!params.containsKey('message')) {

packages/a2a_dart/lib/src/server/delete_push_config_handler.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class DeletePushConfigHandler implements RequestHandler {
1717
@override
1818
String get method => 'tasks/pushNotificationConfig/delete';
1919

20+
@override
21+
List<Map<String, List<String>>>? get securityRequirements => null;
22+
2023
@override
2124
Future<HandlerResult> handle(Map<String, Object?> params) async {
2225
final taskId = params['id'] as String?;

packages/a2a_dart/lib/src/server/get_push_config_handler.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class GetPushConfigHandler implements RequestHandler {
1818
@override
1919
String get method => 'tasks/pushNotificationConfig/get';
2020

21+
@override
22+
List<Map<String, List<String>>>? get securityRequirements => null;
23+
2124
@override
2225
Future<HandlerResult> handle(Map<String, Object?> params) async {
2326
final taskId = params['id'] as String?;

packages/a2a_dart/lib/src/server/get_task_handler.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class GetTaskHandler extends RequestHandler {
2020
@override
2121
String get method => 'tasks/get';
2222

23+
@override
24+
List<Map<String, List<String>>>? get securityRequirements => null;
25+
2326
@override
2427
FutureOr<HandlerResult> handle(Map<String, Object?> params) async {
2528
final taskId = params['id'] as String?;

packages/a2a_dart/lib/src/server/in_memory_task_manager.dart

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ class InMemoryTaskManager implements TaskManager {
1919
final _uuid = const Uuid();
2020
final _pushConfigs = <String, Map<String, PushNotificationConfig>>{};
2121

22-
@override
2322
@override
2423
Future<Task> createTask([Message? message]) async {
2524
final taskId = _uuid.v4();
@@ -35,19 +34,16 @@ class InMemoryTaskManager implements TaskManager {
3534
return task;
3635
}
3736

38-
@override
3937
@override
4038
Future<Task?> getTask(String id) async => _tasks[id];
4139

42-
@override
4340
@override
4441
Future<void> updateTask(Task task) async {
4542
_tasks[task.id] = task.copyWith(
4643
lastUpdated: DateTime.now().millisecondsSinceEpoch,
4744
);
4845
}
4946

50-
@override
5147
@override
5248
Future<Task?> cancelTask(String taskId) async {
5349
final task = _tasks[taskId];
@@ -62,7 +58,6 @@ class InMemoryTaskManager implements TaskManager {
6258
return null;
6359
}
6460

65-
@override
6661
@override
6762
Future<ListTasksResult> listTasks(ListTasksParams params) async {
6863
var tasks = _tasks.values.toList();

packages/a2a_dart/lib/src/server/io_a2a_server.dart

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import 'dart:io';
88

99
import 'package:logging/logging.dart';
1010
import 'package:shelf/shelf.dart';
11-
import 'package:shelf/shelf_io.dart' as io;
11+
import 'package:shelf/shelf_io.dart' as shelf_io;
1212
import 'package:shelf_router/shelf_router.dart';
1313

1414
import '../core/agent_card.dart';
@@ -44,6 +44,8 @@ class A2AServer {
4444
final int _requestedPort;
4545

4646
final Map<String, RequestHandler> _handlers = {};
47+
final TaskManager taskManager;
48+
final Middleware? initialMiddleware;
4749

4850
/// The public agent card for this server.
4951
///
@@ -78,12 +80,13 @@ class A2AServer {
7880
/// ```
7981
A2AServer(
8082
List<RequestHandler> handlers,
81-
TaskManager taskManager, {
83+
this.taskManager, {
8284
this.host = 'localhost',
8385
int port = 0,
8486
Logger? logger,
8587
this.agentCard,
8688
this.extendedAgentCard,
89+
this.initialMiddleware,
8790
}) : _requestedPort = port,
8891
_log = logger {
8992
for (final handler in handlers) {
@@ -128,6 +131,9 @@ class A2AServer {
128131
..get('/.well-known/agent-card.json', _handleAgentCardRequest);
129132

130133
var pipeline = const Pipeline();
134+
if (initialMiddleware != null) {
135+
pipeline = pipeline.addMiddleware(initialMiddleware!);
136+
}
131137
if (_log != null) {
132138
pipeline = pipeline.addMiddleware(
133139
logRequests(
@@ -144,7 +150,7 @@ class A2AServer {
144150
final handler = pipeline.addHandler(router.call);
145151

146152
_log?.info('Starting A2A server on $host:$_requestedPort...');
147-
_server = await io.serve(handler, host, _requestedPort);
153+
_server = await shelf_io.serve(handler, host, _requestedPort);
148154
_log?.info(
149155
'A2A server started on ${_server!.address.host}:${_server!.port}',
150156
);
@@ -169,19 +175,23 @@ class A2AServer {
169175

170176
Future<Response> _handleRpcRequest(Request request) async {
171177
_log?.info('Received request: ${request.method} ${request.requestedUri}');
172-
final body = await request.readAsString();
173-
_log?.fine('Request body: $body');
174178

175179
Object? id;
176-
Map<String, Object?> json;
177-
try {
178-
json = jsonDecode(body) as Map<String, Object?>;
179-
id = json['id'];
180-
} on FormatException {
181-
return _jsonRpcError(id: null, code: -32700, message: 'Parse error');
180+
var json = request.context['a2a_body'] as Map<String, Object?>?;
181+
182+
if (json == null) {
183+
// This should not happen if the security middleware ran first for /rpc requests
184+
final body = await request.readAsString();
185+
_log?.fine('Request body: $body');
186+
try {
187+
json = jsonDecode(body) as Map<String, Object?>;
188+
} on FormatException {
189+
return _jsonRpcError(id: null, code: -32700, message: 'Parse error');
190+
}
182191
}
183192

184193
try {
194+
id = json['id'];
185195
final method = json['method'] as String?;
186196
final params = json['params'] as Map<String, Object?>?;
187197

@@ -198,6 +208,54 @@ class A2AServer {
198208
responseCode: 404,
199209
);
200210
}
211+
// Security Check
212+
final securityRequirements = handler.securityRequirements;
213+
if (securityRequirements != null && securityRequirements.isNotEmpty) {
214+
final authContext =
215+
request.context['a2a_auth'] as Map<String, dynamic>?;
216+
217+
if (authContext == null || authContext['isAuthenticated'] != true) {
218+
return _jsonRpcError(
219+
id: id,
220+
code: -32002,
221+
message: 'Unauthorized: Missing or failed authentication.',
222+
responseCode: 401,
223+
);
224+
}
225+
226+
final authenticatedSchemes =
227+
authContext['schemes'] as Map<String, List<String>>? ?? {};
228+
var authorized = false;
229+
for (final requirement in securityRequirements) {
230+
var requirementMet = true;
231+
for (final schemeName in requirement.keys) {
232+
final requiredScopes = requirement[schemeName]!;
233+
if (!authenticatedSchemes.containsKey(schemeName)) {
234+
requirementMet = false;
235+
break;
236+
}
237+
final grantedScopes = authenticatedSchemes[schemeName]!;
238+
if (!requiredScopes.every(grantedScopes.contains)) {
239+
requirementMet = false;
240+
break;
241+
}
242+
}
243+
if (requirementMet) {
244+
authorized = true;
245+
break;
246+
}
247+
}
248+
249+
if (!authorized) {
250+
return _jsonRpcError(
251+
id: id,
252+
code: -32002,
253+
message: 'Unauthorized: Insufficient permissions for this method.',
254+
responseCode: 401,
255+
);
256+
}
257+
}
258+
201259
return _executeHandler(handler, params, id);
202260
} on Exception catch (exception, stackTrace) {
203261
_log?.severe('Unhandled server exception', exception, stackTrace);

packages/a2a_dart/lib/src/server/list_push_configs_handler.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class ListPushConfigsHandler implements RequestHandler {
1717
@override
1818
String get method => 'tasks/pushNotificationConfig/list';
1919

20+
@override
21+
List<Map<String, List<String>>>? get securityRequirements => null;
22+
2023
@override
2124
Future<HandlerResult> handle(Map<String, Object?> params) async {
2225
final taskId = params['id'] as String?;

packages/a2a_dart/lib/src/server/list_tasks_handler.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class ListTasksHandler extends RequestHandler {
2020
@override
2121
String get method => 'tasks/list';
2222

23+
@override
24+
List<Map<String, List<String>>>? get securityRequirements => null;
25+
2326
@override
2427
FutureOr<HandlerResult> handle(Map<String, Object?> params) async {
2528
final listTasksParams = ListTasksParams.fromJson(params);

0 commit comments

Comments
 (0)