diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RouteBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RouteBuilderTests.cs new file mode 100644 index 0000000000..a734b82b66 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RouteBuilderTests.cs @@ -0,0 +1,546 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Execution; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +public sealed class RouteBuilderTests +{ + public enum HandlerOverload + { + SyncWithCancellation = 0, + SyncWithoutCancellation = 1, + AsyncWithCancellation = 2, + AsyncWithoutCancellation = 3, + } + + private sealed record TestPayload(string Value); + + private sealed class HandlerInvocation + { + public object? Message { get; private set; } + + public IWorkflowContext? Context { get; private set; } + + public CancellationToken CancellationToken { get; private set; } + + public int InvocationCount { get; private set; } + + public void Capture(object? message, IWorkflowContext context, CancellationToken cancellationToken = default) + { + this.Message = message; + this.Context = context; + this.CancellationToken = cancellationToken; + this.InvocationCount++; + } + } + + private sealed class TestExternalRequestContext : IExternalRequestContext, IExternalRequestSink + { + public List RegisteredPorts { get; } = []; + + public List PostedRequests { get; } = []; + + public IExternalRequestSink RegisterPort(RequestPort port) + { + this.RegisteredPorts.Add(port); + return this; + } + + public ValueTask PostAsync(ExternalRequest request) + { + this.PostedRequests.Add(request); + return default; + } + } + + [Theory] + [InlineData(HandlerOverload.SyncWithCancellation)] + [InlineData(HandlerOverload.SyncWithoutCancellation)] + [InlineData(HandlerOverload.AsyncWithCancellation)] + [InlineData(HandlerOverload.AsyncWithoutCancellation)] + public async Task AddHandler_VoidOverloads_RouteExpectedMessageAsync(HandlerOverload overload) + { + // Arrange + RouteBuilder routeBuilder = new(null); + HandlerInvocation invocation = new(); + CancellationToken cancellationToken = new CancellationTokenSource().Token; + RegisterVoidHandler(routeBuilder, invocation, overload); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + CallResult? result = await router.RouteMessageAsync("hello", context, cancellationToken: cancellationToken); + + // Assert + result.Should().NotBeNull(); + result!.IsSuccess.Should().BeTrue(); + result.IsVoid.Should().BeTrue(); + result.Result.Should().BeNull(); + invocation.InvocationCount.Should().Be(1); + invocation.Message.Should().Be("hello"); + invocation.Context.Should().BeSameAs(context); + + if (UsesCancellationToken(overload)) + { + invocation.CancellationToken.Should().Be(cancellationToken); + } + } + + [Theory] + [InlineData(HandlerOverload.SyncWithCancellation)] + [InlineData(HandlerOverload.SyncWithoutCancellation)] + [InlineData(HandlerOverload.AsyncWithCancellation)] + [InlineData(HandlerOverload.AsyncWithoutCancellation)] + public async Task AddHandler_ResultOverloads_RouteExpectedMessageAsync(HandlerOverload overload) + { + // Arrange + RouteBuilder routeBuilder = new(null); + HandlerInvocation invocation = new(); + CancellationToken cancellationToken = new CancellationTokenSource().Token; + RegisterResultHandler(routeBuilder, invocation, overload); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + CallResult? result = await router.RouteMessageAsync("hello", context, cancellationToken: cancellationToken); + + // Assert + result.Should().NotBeNull(); + result!.IsSuccess.Should().BeTrue(); + result.IsVoid.Should().BeFalse(); + result.Result.Should().Be("HELLO"); + router.DefaultOutputTypes.Should().Contain(typeof(string)); + invocation.InvocationCount.Should().Be(1); + invocation.Message.Should().Be("hello"); + invocation.Context.Should().BeSameAs(context); + + if (UsesCancellationToken(overload)) + { + invocation.CancellationToken.Should().Be(cancellationToken); + } + } + + [Theory] + [InlineData(HandlerOverload.SyncWithCancellation)] + [InlineData(HandlerOverload.SyncWithoutCancellation)] + [InlineData(HandlerOverload.AsyncWithCancellation)] + [InlineData(HandlerOverload.AsyncWithoutCancellation)] + public async Task AddCatchAll_VoidOverloads_RouteUnexpectedMessageAsync(HandlerOverload overload) + { + // Arrange + RouteBuilder routeBuilder = new(null); + HandlerInvocation invocation = new(); + CancellationToken cancellationToken = new CancellationTokenSource().Token; + TestPayload payload = new("hello"); + RegisterVoidCatchAll(routeBuilder, invocation, overload); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + CallResult? result = await router.RouteMessageAsync(payload, context, cancellationToken: cancellationToken); + + // Assert + result.Should().NotBeNull(); + result!.IsSuccess.Should().BeTrue(); + result.IsVoid.Should().BeTrue(); + result.Result.Should().BeNull(); + invocation.InvocationCount.Should().Be(1); + invocation.Message.Should().BeEquivalentTo(new PortableValue(payload)); + invocation.Context.Should().BeSameAs(context); + + if (UsesCancellationToken(overload)) + { + invocation.CancellationToken.Should().Be(cancellationToken); + } + } + + [Theory] + [InlineData(HandlerOverload.SyncWithCancellation)] + [InlineData(HandlerOverload.SyncWithoutCancellation)] + [InlineData(HandlerOverload.AsyncWithCancellation)] + [InlineData(HandlerOverload.AsyncWithoutCancellation)] + public async Task AddCatchAll_ResultOverloads_RouteUnexpectedMessageAsync(HandlerOverload overload) + { + // Arrange + RouteBuilder routeBuilder = new(null); + HandlerInvocation invocation = new(); + CancellationToken cancellationToken = new CancellationTokenSource().Token; + TestPayload payload = new("hello"); + RegisterResultCatchAll(routeBuilder, invocation, overload); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + CallResult? result = await router.RouteMessageAsync(payload, context, cancellationToken: cancellationToken); + + // Assert + result.Should().NotBeNull(); + result!.IsSuccess.Should().BeTrue(); + result.IsVoid.Should().BeFalse(); + result.Result.Should().Be("HELLO"); + invocation.InvocationCount.Should().Be(1); + invocation.Message.Should().BeEquivalentTo(new PortableValue(payload)); + invocation.Context.Should().BeSameAs(context); + + if (UsesCancellationToken(overload)) + { + invocation.CancellationToken.Should().Be(cancellationToken); + } + } + + [Fact] + public async Task AddHandlerUntyped_VoidAndResultOverloads_RouteExpectedMessageAsync() + { + // Arrange + RouteBuilder routeBuilder = new(null); + HandlerInvocation voidInvocation = new(); + HandlerInvocation resultInvocation = new(); + CancellationToken cancellationToken = new CancellationTokenSource().Token; + routeBuilder.AddHandlerUntyped(typeof(string), (message, context, token) => + { + voidInvocation.Capture(message, context, token); + return default; + }); + routeBuilder.AddHandlerUntyped(typeof(int), (message, context, token) => + { + resultInvocation.Capture(message, context, token); + return new((int)message + 1); + }); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + CallResult? voidResult = await router.RouteMessageAsync("hello", context, cancellationToken: cancellationToken); + CallResult? typedResult = await router.RouteMessageAsync(41, context, cancellationToken: cancellationToken); + + // Assert + voidResult.Should().NotBeNull(); + voidResult!.IsVoid.Should().BeTrue(); + voidInvocation.Message.Should().Be("hello"); + voidInvocation.Context.Should().BeSameAs(context); + voidInvocation.CancellationToken.Should().Be(cancellationToken); + + typedResult.Should().NotBeNull(); + typedResult!.Result.Should().Be(42); + router.DefaultOutputTypes.Should().Contain(typeof(int)); + resultInvocation.Message.Should().Be(41); + resultInvocation.Context.Should().BeSameAs(context); + resultInvocation.CancellationToken.Should().Be(cancellationToken); + } + + [Fact] + public void AddHandler_ForPortableValue_ThrowsInvalidOperationException() + { + // Arrange + RouteBuilder routeBuilder = new(null); + + // Act + Action act = () => routeBuilder.AddHandler((message, context) => { }); + + // Assert + act.Should().Throw() + .WithMessage("*Use AddCatchAll()*"); + } + + [Fact] + public void AddHandler_DuplicateRegistrationWithoutOverwrite_ThrowsArgumentException() + { + // Arrange + RouteBuilder routeBuilder = new(null); + routeBuilder.AddHandler((message, context) => { }); + + // Act + Action act = () => routeBuilder.AddHandler((message, context) => { }); + + // Assert + act.Should().Throw() + .WithMessage("*already registered*"); + } + + [Fact] + public void AddHandler_OverwriteWithoutExistingRegistration_ThrowsArgumentException() + { + // Arrange + RouteBuilder routeBuilder = new(null); + + // Act + Action act = () => routeBuilder.AddHandler((message, context) => { }, overwrite: true); + + // Assert + act.Should().Throw() + .WithMessage("*has not yet been registered*"); + } + + [Fact] + public async Task AddHandler_OverwriteExistingRegistration_RoutesUpdatedHandlerAsync() + { + // Arrange + RouteBuilder routeBuilder = new(null); + routeBuilder.AddHandler((message, context) => context.SendMessageAsync("first")); + routeBuilder.AddHandler((message, context) => context.SendMessageAsync("second"), overwrite: true); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + _ = await router.RouteMessageAsync("hello", context); + + // Assert + context.SentMessages.Should().ContainSingle().Which.Should().Be("second"); + } + + [Fact] + public void AddCatchAll_DuplicateRegistrationWithoutOverwrite_ThrowsInvalidOperationException() + { + // Arrange + RouteBuilder routeBuilder = new(null); + routeBuilder.AddCatchAll((message, context) => { }); + + // Act + Action act = () => routeBuilder.AddCatchAll((message, context) => { }); + + // Assert + act.Should().Throw() + .WithMessage("*already registered*"); + } + + [Fact] + public async Task AddCatchAll_OverwriteExistingRegistration_RoutesUpdatedHandlerAsync() + { + // Arrange + RouteBuilder routeBuilder = new(null); + routeBuilder.AddCatchAll((message, context) => context.SendMessageAsync("first")); + routeBuilder.AddCatchAll((message, context) => context.SendMessageAsync("second"), overwrite: true); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + + // Act + _ = await router.RouteMessageAsync(new TestPayload("hello"), context); + + // Assert + context.SentMessages.Should().ContainSingle().Which.Should().Be("second"); + } + + [Fact] + public void AddPortHandler_WithoutExternalRequestContext_ThrowsInvalidOperationException() + { + // Arrange + RouteBuilder routeBuilder = new(null); + + // Act + Action act = () => routeBuilder.AddPortHandler("port", (response, context, cancellationToken) => default, out _); + + // Assert + act.Should().Throw() + .WithMessage("*external request context is required*"); + } + + [Fact] + public async Task AddPortHandler_RoutesMatchingExternalResponseAsync() + { + // Arrange + TestExternalRequestContext externalRequestContext = new(); + RouteBuilder routeBuilder = new(externalRequestContext); + HandlerInvocation invocation = new(); + routeBuilder.AddPortHandler("port", (response, context, cancellationToken) => + { + invocation.Capture(response, context, cancellationToken); + return default; + }, out PortBinding portBinding); + await portBinding.PostRequestAsync("request", requestId: "req-1"); + MessageRouter router = routeBuilder.Build(); + TestWorkflowContext context = new("executor"); + CancellationToken cancellationToken = new CancellationTokenSource().Token; + ExternalResponse response = externalRequestContext.PostedRequests.Single().CreateResponse(42); + + // Act + CallResult? result = await router.RouteMessageAsync(response, context, cancellationToken: cancellationToken); + + // Assert + externalRequestContext.RegisteredPorts.Should().ContainSingle(port => port.Id == "port"); + externalRequestContext.PostedRequests.Should().ContainSingle(request => request.RequestId == "req-1"); + result.Should().NotBeNull(); + result!.IsSuccess.Should().BeTrue(); + result.Result.Should().BeSameAs(response); + invocation.InvocationCount.Should().Be(1); + invocation.Message.Should().Be(42); + invocation.Context.Should().BeSameAs(context); + invocation.CancellationToken.Should().Be(cancellationToken); + } + + [Fact] + public async Task AddPortHandler_UnknownPort_ReturnsExceptionResultAsync() + { + // Arrange + TestExternalRequestContext externalRequestContext = new(); + RouteBuilder routeBuilder = new(externalRequestContext); + routeBuilder.AddPortHandler("port", (response, context, cancellationToken) => default, out _); + MessageRouter router = routeBuilder.Build(); + ExternalRequest request = ExternalRequest.Create(RequestPort.Create("other"), "request", requestId: "req-1"); + + // Act + CallResult? result = await router.RouteMessageAsync(request.CreateResponse(42), new TestWorkflowContext("executor")); + + // Assert + result.Should().NotBeNull(); + result!.IsSuccess.Should().BeFalse(); + result.Exception.Should().BeOfType(); + result.Exception!.Message.Should().Contain("Unknown port"); + } + + private static void RegisterVoidHandler(RouteBuilder routeBuilder, HandlerInvocation invocation, HandlerOverload overload) + { + switch (overload) + { + case HandlerOverload.SyncWithCancellation: + routeBuilder.AddHandler((message, context, cancellationToken) => invocation.Capture(message, context, cancellationToken)); + break; + case HandlerOverload.SyncWithoutCancellation: + routeBuilder.AddHandler((message, context) => invocation.Capture(message, context)); + break; + case HandlerOverload.AsyncWithCancellation: + routeBuilder.AddHandler((message, context, cancellationToken) => + { + invocation.Capture(message, context, cancellationToken); + return default; + }); + break; + case HandlerOverload.AsyncWithoutCancellation: + routeBuilder.AddHandler((message, context) => + { + invocation.Capture(message, context); + return default; + }); + break; + default: + throw new ArgumentOutOfRangeException(nameof(overload)); + } + } + + private static void RegisterResultHandler(RouteBuilder routeBuilder, HandlerInvocation invocation, HandlerOverload overload) + { + switch (overload) + { + case HandlerOverload.SyncWithCancellation: + routeBuilder.AddHandler((message, context, cancellationToken) => + { + invocation.Capture(message, context, cancellationToken); + return NormalizeHandlerResult(message); + }); + break; + case HandlerOverload.SyncWithoutCancellation: + routeBuilder.AddHandler((message, context) => + { + invocation.Capture(message, context); + return NormalizeHandlerResult(message); + }); + break; + case HandlerOverload.AsyncWithCancellation: + Func> asyncHandlerWithCancellation = (message, context, cancellationToken) => + { + invocation.Capture(message, context, cancellationToken); + return new ValueTask(NormalizeHandlerResult(message)); + }; + routeBuilder.AddHandler(asyncHandlerWithCancellation); + break; + case HandlerOverload.AsyncWithoutCancellation: + Func> asyncHandler = (message, context) => + { + invocation.Capture(message, context); + return new ValueTask(NormalizeHandlerResult(message)); + }; + routeBuilder.AddHandler(asyncHandler); + break; + default: + throw new ArgumentOutOfRangeException(nameof(overload)); + } + } + + private static void RegisterVoidCatchAll(RouteBuilder routeBuilder, HandlerInvocation invocation, HandlerOverload overload) + { + switch (overload) + { + case HandlerOverload.SyncWithCancellation: + routeBuilder.AddCatchAll((message, context, cancellationToken) => invocation.Capture(message, context, cancellationToken)); + break; + case HandlerOverload.SyncWithoutCancellation: + routeBuilder.AddCatchAll((message, context) => invocation.Capture(message, context)); + break; + case HandlerOverload.AsyncWithCancellation: + routeBuilder.AddCatchAll((message, context, cancellationToken) => + { + invocation.Capture(message, context, cancellationToken); + return default; + }); + break; + case HandlerOverload.AsyncWithoutCancellation: + routeBuilder.AddCatchAll((message, context) => + { + invocation.Capture(message, context); + return default; + }); + break; + default: + throw new ArgumentOutOfRangeException(nameof(overload)); + } + } + + private static void RegisterResultCatchAll(RouteBuilder routeBuilder, HandlerInvocation invocation, HandlerOverload overload) + { + switch (overload) + { + case HandlerOverload.SyncWithCancellation: + routeBuilder.AddCatchAll((message, context, cancellationToken) => + { + invocation.Capture(message, context, cancellationToken); + return NormalizeCatchAllResult(message); + }); + break; + case HandlerOverload.SyncWithoutCancellation: + routeBuilder.AddCatchAll((message, context) => + { + invocation.Capture(message, context); + return NormalizeCatchAllResult(message); + }); + break; + case HandlerOverload.AsyncWithCancellation: + Func> asyncCatchAllWithCancellation = (message, context, cancellationToken) => + { + invocation.Capture(message, context, cancellationToken); + return new ValueTask(NormalizeCatchAllResult(message)); + }; + routeBuilder.AddCatchAll(asyncCatchAllWithCancellation); + break; + case HandlerOverload.AsyncWithoutCancellation: + Func> asyncCatchAll = (message, context) => + { + invocation.Capture(message, context); + return new ValueTask(NormalizeCatchAllResult(message)); + }; + routeBuilder.AddCatchAll(asyncCatchAll); + break; + default: + throw new ArgumentOutOfRangeException(nameof(overload)); + } + } + + private static bool UsesCancellationToken(HandlerOverload overload) => + overload is HandlerOverload.SyncWithCancellation or HandlerOverload.AsyncWithCancellation; + + private static string NormalizeHandlerResult(string message) => message.ToUpperInvariant(); + + private static string NormalizeCatchAllResult(PortableValue message) => GetPayloadValue(message).ToUpperInvariant(); + + private static string GetPayloadValue(PortableValue message) + { + return message.As() is TestPayload payload + ? payload.Value + : throw new InvalidOperationException("Expected catch-all message payload to deserialize as TestPayload."); + } +}