Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions all.sln
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ServiceInvocationDemo.Servi
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Common", "examples\Hosting\Aspire\ServiceInvocationDemo\Common\Common.csproj", "{6CD90C22-0F79-4D61-8DCE-5BE22C1304C4}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Dapr.AI.A2A", "src\Dapr.AI.A2A\Dapr.AI.A2A.csproj", "{AE9804A8-906C-4C3B-B2A8-41F4D3269C19}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -509,6 +511,10 @@ Global
{6CD90C22-0F79-4D61-8DCE-5BE22C1304C4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6CD90C22-0F79-4D61-8DCE-5BE22C1304C4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6CD90C22-0F79-4D61-8DCE-5BE22C1304C4}.Release|Any CPU.Build.0 = Release|Any CPU
{AE9804A8-906C-4C3B-B2A8-41F4D3269C19}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{AE9804A8-906C-4C3B-B2A8-41F4D3269C19}.Debug|Any CPU.Build.0 = Debug|Any CPU
{AE9804A8-906C-4C3B-B2A8-41F4D3269C19}.Release|Any CPU.ActiveCfg = Release|Any CPU
{AE9804A8-906C-4C3B-B2A8-41F4D3269C19}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -602,6 +608,7 @@ Global
{97A47B0B-9D3B-4CF0-A62C-650F2F211A59} = {55E08C7F-81C8-4D0B-AB18-87C89B261477}
{5BB15C36-BAF7-44F6-BF85-C533B8B47862} = {55E08C7F-81C8-4D0B-AB18-87C89B261477}
{6CD90C22-0F79-4D61-8DCE-5BE22C1304C4} = {55E08C7F-81C8-4D0B-AB18-87C89B261477}
{AE9804A8-906C-4C3B-B2A8-41F4D3269C19} = {27C5D71D-0721-4221-9286-B94AB07B58CF}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {65220BF2-EAE1-4CB2-AA58-EBE80768CB40}
Expand Down
37 changes: 20 additions & 17 deletions src/Dapr.AI.A2A/DaprTaskStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public class DaprTaskStore : ITaskStore
/// <param name="stateStoreName">The name of the state store component to use</param>
public DaprTaskStore(DaprClient daprClient, string stateStoreName = "statestore")
{
_daprClient = daprClient ?? throw new ArgumentNullException(nameof(daprClient));
ArgumentNullException.ThrowIfNull(daprClient, nameof(daprClient));
_daprClient = daprClient;
_stateStoreName = stateStoreName;
}

Expand All @@ -40,10 +41,10 @@ public DaprTaskStore(DaprClient daprClient, string stateStoreName = "statestore"
/// <returns>The task if found, null otherwise.</returns>
public async Task<AgentTask?> GetTaskAsync(string taskId, CancellationToken cancellationToken = default)
{
if (taskId == null) throw new ArgumentNullException(nameof(taskId));
ArgumentNullException.ThrowIfNull(taskId, nameof(taskId));

// Retrieve the AgentTask from Dapr state store with strong consistency to get the latest data
AgentTask? task = await _daprClient.GetStateAsync<AgentTask>(
var task = await _daprClient.GetStateAsync<AgentTask>(
_stateStoreName,
key: taskId,
consistencyMode: ConsistencyMode.Strong,
Expand All @@ -60,7 +61,7 @@ public DaprTaskStore(DaprClient daprClient, string stateStoreName = "statestore"
/// <returns>A task representing the operation.</returns>
public async Task SetTaskAsync(AgentTask task, CancellationToken cancellationToken = default)
{
if (task == null) throw new ArgumentNullException(nameof(task));
ArgumentNullException.ThrowIfNull(task, nameof(task));
// The task.Id will be used as the key. We save the entire AgentTask object.
// Use strong consistency on write; concurrency defaults to last-write-wins for new entries.
await _daprClient.SaveStateAsync(
Expand All @@ -84,7 +85,7 @@ await _daprClient.SaveStateAsync(
/// <returns>The updated task status.</returns>
public async Task<AgentTaskStatus> UpdateStatusAsync(string taskId, TaskState status, Message? message = null, CancellationToken cancellationToken = default)
{
if (taskId == null) throw new ArgumentNullException(nameof(taskId));
ArgumentNullException.ThrowIfNull(taskId, nameof(taskId));
// Fetch state with its ETag for concurrency control.
// We use strong consistency to get the latest state and ETag.
var (existingTask, etag) = await _daprClient.GetStateAndETagAsync<AgentTask>(
Expand Down Expand Up @@ -141,9 +142,9 @@ public async Task<AgentTaskStatus> UpdateStatusAsync(string taskId, TaskState st
/// <returns>The push notification configuration if found, null otherwise.</returns>
public async Task<TaskPushNotificationConfig?> GetPushNotificationAsync(string taskId, string notificationConfigId, CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(taskId)) throw new ArgumentNullException(nameof(taskId));
if (string.IsNullOrWhiteSpace(notificationConfigId)) throw new ArgumentNullException(nameof(notificationConfigId));

ArgumentException.ThrowIfNullOrWhiteSpace(taskId, nameof(taskId));
ArgumentException.ThrowIfNullOrWhiteSpace(notificationConfigId, nameof(notificationConfigId));
return await _daprClient.GetStateAsync<TaskPushNotificationConfig>(
_stateStoreName,
key: PushCfgKey(taskId, notificationConfigId),
Expand All @@ -160,7 +161,7 @@ public async Task<AgentTaskStatus> UpdateStatusAsync(string taskId, TaskState st
/// <returns>A task representing the operation.</returns>
public async Task SetPushNotificationConfigAsync(TaskPushNotificationConfig pushNotificationConfig, CancellationToken cancellationToken = default)
{
if (pushNotificationConfig is null) throw new ArgumentNullException(nameof(pushNotificationConfig));
ArgumentNullException.ThrowIfNull(pushNotificationConfig, nameof(pushNotificationConfig));

// Adjust these property names if your model differs:
var taskId = pushNotificationConfig.TaskId ?? throw new ArgumentException("Config.TaskId is required.");
Expand All @@ -185,7 +186,7 @@ await _daprClient.SaveStateAsync(
metadata: null,
cancellationToken: cancellationToken);

var list = (index ?? Array.Empty<string>()).ToList();
var list = index.ToList();
if (!list.Contains(configId, StringComparer.Ordinal))
list.Add(configId);

Expand Down Expand Up @@ -213,34 +214,36 @@ await _daprClient.SaveStateAsync(
/// <returns>The push notification configuration if found, null otherwise.</returns>
public async Task<IEnumerable<TaskPushNotificationConfig>> GetPushNotificationsAsync(string taskId, CancellationToken cancellationToken = default)
{
if (string.IsNullOrWhiteSpace(taskId)) throw new ArgumentNullException(nameof(taskId));
ArgumentNullException.ThrowIfNull(taskId, nameof(taskId));

var ids = await _daprClient.GetStateAsync<string[]>(
_stateStoreName,
key: PushCfgIndexKey(taskId),
consistencyMode: ConsistencyMode.Strong,
metadata: null,
cancellationToken: cancellationToken) ?? Array.Empty<string>();
cancellationToken: cancellationToken);

if (ids.Length == 0) return Array.Empty<TaskPushNotificationConfig>();
if (ids.Length == 0)
return [];

const int maxParallel = 8;
using var gate = new SemaphoreSlim(maxParallel);
var bag = new ConcurrentBag<TaskPushNotificationConfig>();

await Task.WhenAll(ids.Select(async id =>
{
using var gate = new SemaphoreSlim(maxParallel);
await gate.WaitAsync(cancellationToken);
try
{
var cfg = await _daprClient.GetStateAsync<TaskPushNotificationConfig>(
var cfg = await _daprClient.GetStateAsync<TaskPushNotificationConfig?>(
_stateStoreName,
key: PushCfgKey(taskId, id),
consistencyMode: ConsistencyMode.Strong,
metadata: null,
cancellationToken: cancellationToken);

if (cfg is not null) bag.Add(cfg);
if (cfg is not null)
bag.Add(cfg);
}
finally
{
Expand All @@ -250,4 +253,4 @@ await Task.WhenAll(ids.Select(async id =>

return bag.ToArray();
}
}
}