diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 16a69cdf0..e6cb3b9bb 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -596,11 +596,22 @@ public async Task HandlesIProgressParameter() McpClientTool progressTool = tools.First(t => t.Name == nameof(EchoTool.SendsProgressNotifications)); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + int remainingNotifications = 10; + ConcurrentQueue notifications = new(); await using (client.RegisterNotificationHandler(NotificationMethods.ProgressNotification, (notification, cancellationToken) => { - ProgressNotification pn = JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions)!; - notifications.Enqueue(pn); + if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions) is { } pn && + pn.ProgressToken == new ProgressToken("abc123")) + { + notifications.Enqueue(pn); + if (Interlocked.Decrement(ref remainingNotifications) == 0) + { + tcs.SetResult(); + } + } + return default; })) { @@ -613,8 +624,8 @@ public async Task HandlesIProgressParameter() }, cancellationToken: TestContext.Current.CancellationToken); + await tcs.Task; Assert.Contains("done", JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions)); - SpinWait.SpinUntil(() => notifications.Count == 10, TimeSpan.FromSeconds(10)); } ProgressNotification[] array = notifications.OrderBy(n => n.Progress.Progress).ToArray();