diff --git a/docs/guide/handlers/middleware.md b/docs/guide/handlers/middleware.md index 157fef73c..28ba8d1fd 100644 --- a/docs/guide/handlers/middleware.md +++ b/docs/guide/handlers/middleware.md @@ -641,3 +641,49 @@ public static class MaybeBadThing4Handler And any objects in the `OutgoingMessages` return value from the middleware method will be sent as cascaded messages. Wolverine will also apply a "maybe stop" frame from the `IHandlerContinuation` as well. +## Parameter Value Sources + +When using `WolverineParameterAttribute` subclasses (like `[Aggregate]`, `[WriteAggregate]`), you can control +where parameter values are resolved from using the `ValueSource` property or the convenience shorthand properties. + +### From an Envelope Header + +Use `FromHeader` to resolve a value from the message envelope's headers: + +```cs +public void Handle( + ProcessOrder command, + [WriteAggregate(FromHeader = "X-Tenant-Id")] TenantAggregate tenant) +{ + // tenant loaded using the value from the "X-Tenant-Id" envelope header +} +``` + +### From a Static Method + +Use `FromMethod` to resolve a value from a static method on the handler class. The method's parameters +are resolved via method injection: + +```cs +public class ProcessOrderHandler +{ + public static Guid ResolveTenantId(IMessageContext context) + { + context.Envelope!.TryGetHeader("X-Tenant-Id", out var tenantId); + return Guid.Parse(tenantId!); + } + + public void Handle( + ProcessOrder command, + [WriteAggregate(FromMethod = "ResolveTenantId")] TenantAggregate tenant) + { + // tenant loaded using the Guid returned by ResolveTenantId() + } +} +``` + +::: warning +`FromClaim` is only supported in HTTP endpoints and will throw an `InvalidOperationException` if +used in a message handler. +::: + diff --git a/docs/guide/http/marten.md b/docs/guide/http/marten.md index d3231573f..5241753cb 100644 --- a/docs/guide/http/marten.md +++ b/docs/guide/http/marten.md @@ -419,6 +419,86 @@ public static OrderShipped Ship( See [Overriding Version Discovery](/guide/durability/marten/event-sourcing.html#overriding-version-discovery) in the aggregate handler workflow documentation for more details and multi-stream examples. +## Custom Identity Resolution + +By default, the `[Aggregate]` attribute resolves the stream identity from route arguments, query string parameters, +or request body properties. Starting in 5.25, you can use additional value sources to resolve the aggregate identity from +headers, claims, or computed methods. These same properties are available on all `WolverineParameterAttribute` subclasses +(`[Aggregate]`, `[WriteAggregate]`, `[ReadAggregate]`, etc.). + +### From a Request Header + +Use `FromHeader` to resolve the identity from an HTTP request header: + +```cs +[WolverinePost("/orders/ship")] +[EmptyResponse] +public static OrderShipped Ship( + ShipOrder command, + [Aggregate(FromHeader = "X-Order-Id")] Order order) +{ + return new OrderShipped(); +} +``` + +In message handlers, `FromHeader` reads from `Envelope.Headers` instead. + +### From a Claim + +Use `FromClaim` to resolve the identity from the authenticated user's claims. This is only +supported in HTTP endpoints: + +```cs +[WolverinePost("/profile/update")] +[EmptyResponse] +public static ProfileUpdated UpdateProfile( + UpdateProfile command, + [Aggregate(FromClaim = "profile-id")] UserProfile profile) +{ + return new ProfileUpdated(); +} +``` + +### From a Static Method + +Use `FromMethod` to resolve the identity from a static method on the endpoint class. The method's +parameters are resolved via method injection (services, `ClaimsPrincipal`, etc.): + +```cs +public static class UpdateAccountConfigEndpoint +{ + // Wolverine discovers this method and calls it to resolve the aggregate ID + public static Guid ResolveId(ClaimsPrincipal user) + { + return AccountConfig.CompositeId(user.FindFirst("tenant")?.Value); + } + + [WolverinePost("/account/config/update")] + [EmptyResponse] + public static AccountConfigUpdated Handle( + UpdateAccountConfig command, + [Aggregate(FromMethod = "ResolveId")] AccountConfig config) + { + return new AccountConfigUpdated(); + } +} +``` + +### From a Route Argument + +Use `FromRoute` as a more explicit alternative to the constructor parameter: + +```cs +[WolverinePost("/orders/{orderId}/ship")] +[EmptyResponse] +public static OrderShipped Ship( + ShipOrder command, + [Aggregate(FromRoute = "orderId")] Order order) +{ + return new OrderShipped(); +} +``` + ## Reading the Latest Version of an Aggregate ::: info diff --git a/src/Http/Wolverine.Http.Tests/value_source_resolution.cs b/src/Http/Wolverine.Http.Tests/value_source_resolution.cs new file mode 100644 index 000000000..adff4ff0f --- /dev/null +++ b/src/Http/Wolverine.Http.Tests/value_source_resolution.cs @@ -0,0 +1,184 @@ +using System.Security.Claims; +using Alba; +using Shouldly; + +namespace Wolverine.Http.Tests; + +public class value_source_resolution : IntegrationContext +{ + public value_source_resolution(AppFixture fixture) : base(fixture) + { + } + + private static ClaimsPrincipal UserWithClaims(params Claim[] claims) + { + var identity = new ClaimsIdentity(claims, "TestAuth"); + return new ClaimsPrincipal(identity); + } + + #region Header tests + + [Fact] + public async Task from_header_resolves_string_value() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-header/string"); + x.WithRequestHeader("X-Custom-Value", "hello-world"); + }); + + result.ReadAsText().ShouldBe("hello-world"); + } + + [Fact] + public async Task from_header_missing_string_returns_default() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-header/string"); + }); + + result.ReadAsText().ShouldBe("no-value"); + } + + [Fact] + public async Task from_header_resolves_int_value() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-header/int"); + x.WithRequestHeader("X-Count", "42"); + }); + + result.ReadAsText().ShouldBe("count:42"); + } + + [Fact] + public async Task from_header_resolves_guid_value() + { + var id = Guid.NewGuid(); + var result = await Scenario(x => + { + x.Get.Url("/test/from-header/guid"); + x.WithRequestHeader("X-Correlation-Id", id.ToString()); + }); + + result.ReadAsText().ShouldBe($"id:{id}"); + } + + [Fact] + public async Task from_header_int_missing_returns_default() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-header/int"); + }); + + result.ReadAsText().ShouldBe("count:0"); + } + + #endregion + + #region Claim tests + + [Fact] + public async Task from_claim_resolves_string_value() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-claim/string"); + x.ConfigureHttpContext(c => c.User = UserWithClaims(new Claim("sub", "user-123"))); + }); + + result.ReadAsText().ShouldBe("user-123"); + } + + [Fact] + public async Task from_claim_missing_string_returns_default() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-claim/string"); + }); + + result.ReadAsText().ShouldBe("no-user"); + } + + [Fact] + public async Task from_claim_resolves_int_value() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-claim/int"); + x.ConfigureHttpContext(c => c.User = UserWithClaims(new Claim("tenant-id", "42"))); + }); + + result.ReadAsText().ShouldBe("tenant:42"); + } + + [Fact] + public async Task from_claim_resolves_guid_value() + { + var id = Guid.NewGuid(); + var result = await Scenario(x => + { + x.Get.Url("/test/from-claim/guid"); + x.ConfigureHttpContext(c => c.User = UserWithClaims(new Claim("organization-id", id.ToString()))); + }); + + result.ReadAsText().ShouldBe($"org:{id}"); + } + + [Fact] + public async Task from_claim_int_missing_returns_default() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-claim/int"); + }); + + result.ReadAsText().ShouldBe("tenant:0"); + } + + #endregion + + #region Method tests + + [Fact] + public async Task from_method_resolves_guid_value() + { + var id = Guid.NewGuid(); + var result = await Scenario(x => + { + x.Get.Url("/test/from-method/guid"); + x.ConfigureHttpContext(c => c.User = UserWithClaims(new Claim("computed-id", id.ToString()))); + }); + + result.ReadAsText().ShouldBe($"resolved:{id}"); + } + + [Fact] + public async Task from_method_resolves_string_value() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-method/string"); + x.ConfigureHttpContext(c => c.User = UserWithClaims(new Claim("display-name", "Jeremy"))); + }); + + result.ReadAsText().ShouldBe("name:Jeremy"); + } + + [Fact] + public async Task from_method_with_no_claim_returns_default() + { + var result = await Scenario(x => + { + x.Get.Url("/test/from-method/string"); + }); + + result.ReadAsText().ShouldBe("name:anonymous"); + } + + #endregion +} diff --git a/src/Http/Wolverine.Http/CodeGen/ReadClaimFrame.cs b/src/Http/Wolverine.Http/CodeGen/ReadClaimFrame.cs new file mode 100644 index 000000000..8c20f7a41 --- /dev/null +++ b/src/Http/Wolverine.Http/CodeGen/ReadClaimFrame.cs @@ -0,0 +1,83 @@ +using System.Security.Claims; +using JasperFx.CodeGeneration; +using JasperFx.CodeGeneration.Frames; +using JasperFx.CodeGeneration.Model; +using JasperFx.Core.Reflection; + +namespace Wolverine.Http.CodeGen; + +/// +/// Code generation frame that reads a claim value from the ClaimsPrincipal. +/// Supports string and typed values via TryParse. +/// +internal class ReadClaimFrame : SyncFrame +{ + private readonly string _claimType; + private readonly Type _valueType; + private readonly bool _isNullable; + private readonly Type _rawType; + + public ReadClaimFrame(Type valueType, string claimType) + { + _claimType = claimType; + _valueType = valueType; + _isNullable = valueType.IsNullable(); + _rawType = _isNullable ? valueType.GetInnerTypeFromNullable() : valueType; + Variable = new Variable(valueType, $"claim_{claimType.Replace("-", "_").Replace(":", "_").Replace("/", "_")}", this); + } + + public Variable Variable { get; } + + public override void GenerateCode(GeneratedMethod method, ISourceWriter writer) + { + if (_rawType == typeof(string)) + { + writeStringValue(writer); + } + else + { + writeTypedValue(writer); + } + + Next?.GenerateCode(method, writer); + } + + private void writeStringValue(ISourceWriter writer) + { + writer.Write( + $"var {Variable.Usage} = httpContext.User?.FindFirst(\"{_claimType}\")?.Value;"); + } + + private void writeTypedValue(ISourceWriter writer) + { + var typeName = _rawType.FullNameInCode(); + + writer.Write( + $"var {Variable.Usage}_raw = httpContext.User?.FindFirst(\"{_claimType}\")?.Value;"); + writer.Write($"{_valueType.FullNameInCode()} {Variable.Usage} = default;"); + + if (_rawType.IsEnum) + { + writer.Write( + $"BLOCK:if ({Variable.Usage}_raw != null && {typeName}.TryParse<{typeName}>({Variable.Usage}_raw, true, out var {Variable.Usage}_parsed))"); + } + else if (_rawType.IsBoolean()) + { + writer.Write( + $"BLOCK:if ({Variable.Usage}_raw != null && {typeName}.TryParse({Variable.Usage}_raw, out var {Variable.Usage}_parsed))"); + } + else + { + writer.Write( + $"BLOCK:if ({Variable.Usage}_raw != null && {typeName}.TryParse({Variable.Usage}_raw, System.Globalization.CultureInfo.InvariantCulture, out var {Variable.Usage}_parsed))"); + } + + writer.Write($"{Variable.Usage} = {Variable.Usage}_parsed;"); + writer.FinishBlock(); + } + + public override IEnumerable FindVariables(IMethodVariables chain) + { + yield break; + } +} diff --git a/src/Http/Wolverine.Http/HttpChain.ApiDescription.cs b/src/Http/Wolverine.Http/HttpChain.ApiDescription.cs index 8f200adfe..8ceef1b84 100644 --- a/src/Http/Wolverine.Http/HttpChain.ApiDescription.cs +++ b/src/Http/Wolverine.Http/HttpChain.ApiDescription.cs @@ -150,7 +150,22 @@ public override bool TryFindVariable(string valueName, ValueSource source, Type { return true; } - + + if (source == ValueSource.Header && FindHeaderVariable(valueType, valueName, out variable!)) + { + return true; + } + + if (source == ValueSource.Claim && FindClaimVariable(valueType, valueName, out variable!)) + { + return true; + } + + if (source == ValueSource.Method) + { + return tryFindMethodVariable(valueName, valueType, out variable!); + } + if (HasRequestType) { var requestType = InputType()!; @@ -164,16 +179,57 @@ public override bool TryFindVariable(string valueName, ValueSource source, Type if (RequestBodyVariable == null) throw new InvalidOperationException( "Requesting member access to the request body, but the request body is not (yet) set."); - + variable = new MemberAccessVariable(RequestBodyVariable, member); return true; } } - + variable = default!; return false; } + internal bool FindHeaderVariable(Type valueType, string headerName, out Variable variable) + { + var frame = new CodeGen.ReadHttpFrame(CodeGen.BindingSource.Header, valueType, headerName.Replace("-", "_")) + { + Key = headerName + }; + Middleware.Add(frame); + variable = frame.Variable; + return true; + } + + internal bool FindClaimVariable(Type valueType, string claimType, out Variable variable) + { + var frame = new CodeGen.ReadClaimFrame(valueType, claimType); + Middleware.Add(frame); + variable = frame.Variable; + return true; + } + + private bool tryFindMethodVariable(string methodName, Type returnType, out Variable variable) + { + var handlerTypes = HandlerCalls().Select(h => h.HandlerType).Distinct(); + foreach (var type in handlerTypes) + { + var method = type + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy) + .FirstOrDefault(m => m.Name.EqualsIgnoreCase(methodName) && m.ReturnType == returnType); + + if (method != null) + { + var call = new MethodCall(type, method); + Middleware.Add(call); + variable = call.ReturnVariable!; + return true; + } + } + + throw new InvalidOperationException( + $"Could not find a public static method '{methodName}' returning {returnType.FullNameInCode()} on endpoint types: {handlerTypes.Select(t => t.FullNameInCode()).Join(", ")}"); + } + private sealed record NormalizedResponseMetadata(int StatusCode, Type? Type, IEnumerable ContentTypes) { // if an attribute doesn't specific the content type, conform with OpenAPI internals and infer. diff --git a/src/Http/WolverineWebApi/ValueSourceEndpoints.cs b/src/Http/WolverineWebApi/ValueSourceEndpoints.cs new file mode 100644 index 000000000..9dd85f6c2 --- /dev/null +++ b/src/Http/WolverineWebApi/ValueSourceEndpoints.cs @@ -0,0 +1,122 @@ +using System.Reflection; +using System.Security.Claims; +using JasperFx.CodeGeneration; +using JasperFx.CodeGeneration.Model; +using Wolverine.Attributes; +using Wolverine.Configuration; +using Wolverine.Http; +using IServiceContainer = JasperFx.IServiceContainer; + +namespace WolverineWebApi; + +#region sample_FromValueSource_attribute + +/// +/// Simple test attribute that resolves a parameter value from the configured ValueSource. +/// Used for testing the various value source resolution mechanisms. +/// +[AttributeUsage(AttributeTargets.Parameter)] +public class FromValueSourceAttribute : WolverineParameterAttribute +{ + public FromValueSourceAttribute() + { + } + + public FromValueSourceAttribute(string argumentName) : base(argumentName) + { + } + + public override Variable Modify(IChain chain, ParameterInfo parameter, + IServiceContainer container, GenerationRules rules) + { + if (chain.TryFindVariable(ArgumentName ?? parameter.Name!, ValueSource, parameter.ParameterType, out var variable)) + { + return variable; + } + + throw new InvalidOperationException( + $"Could not resolve value for parameter '{parameter.Name}' using ValueSource.{ValueSource} with argument name '{ArgumentName}'"); + } +} + +#endregion + +#region sample_value_source_test_endpoints + +public static class ValueSourceFromHeaderEndpoint +{ + [WolverineGet("/test/from-header/string")] + public static string GetStringHeader( + [FromValueSource(FromHeader = "X-Custom-Value")] string value) + { + return value ?? "no-value"; + } + + [WolverineGet("/test/from-header/int")] + public static string GetIntHeader( + [FromValueSource(FromHeader = "X-Count")] int count) + { + return $"count:{count}"; + } + + [WolverineGet("/test/from-header/guid")] + public static string GetGuidHeader( + [FromValueSource(FromHeader = "X-Correlation-Id")] Guid correlationId) + { + return $"id:{correlationId}"; + } +} + +public static class ValueSourceFromClaimEndpoint +{ + [WolverineGet("/test/from-claim/string")] + public static string GetStringClaim( + [FromValueSource(FromClaim = "sub")] string userId) + { + return userId ?? "no-user"; + } + + [WolverineGet("/test/from-claim/int")] + public static string GetIntClaim( + [FromValueSource(FromClaim = "tenant-id")] int tenantId) + { + return $"tenant:{tenantId}"; + } + + [WolverineGet("/test/from-claim/guid")] + public static string GetGuidClaim( + [FromValueSource(FromClaim = "organization-id")] Guid orgId) + { + return $"org:{orgId}"; + } +} + +public static class ValueSourceFromMethodEndpoint +{ + public static Guid ResolveId(ClaimsPrincipal user) + { + var claim = user.FindFirstValue("computed-id"); + return claim != null ? Guid.Parse(claim) : Guid.Empty; + } + + [WolverineGet("/test/from-method/guid")] + public static string GetMethodValue( + [FromValueSource(FromMethod = "ResolveId")] Guid resolvedId) + { + return $"resolved:{resolvedId}"; + } + + public static string ComputeName(ClaimsPrincipal user) + { + return user.FindFirstValue("display-name") ?? "anonymous"; + } + + [WolverineGet("/test/from-method/string")] + public static string GetMethodString( + [FromValueSource(FromMethod = "ComputeName")] string name) + { + return $"name:{name}"; + } +} + +#endregion diff --git a/src/Testing/CoreTests/Runtime/Handlers/HandlerChain_TryFindVariable.cs b/src/Testing/CoreTests/Runtime/Handlers/HandlerChain_TryFindVariable.cs index 3efff8c78..255b2040c 100644 --- a/src/Testing/CoreTests/Runtime/Handlers/HandlerChain_TryFindVariable.cs +++ b/src/Testing/CoreTests/Runtime/Handlers/HandlerChain_TryFindVariable.cs @@ -1,3 +1,4 @@ +using JasperFx.CodeGeneration.Frames; using JasperFx.CodeGeneration.Services; using Wolverine.Attributes; using Wolverine.Runtime.Handlers; @@ -11,53 +12,201 @@ public class HandlerChain_TryFindVariable public void for_matching_member_name_on_name_and_type() { var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); - + chain.TryFindVariable(nameof(CreateThing.Id), ValueSource.InputMember, typeof(Guid), out var variable) .ShouldBeTrue(); - + variable.ShouldBeOfType() .Member.Name.ShouldBe("Id"); } - + [Fact] public void miss_on_type() { var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); - + chain.TryFindVariable(nameof(CreateThing.Id), ValueSource.InputMember, typeof(int), out var variable) .ShouldBeFalse(); } - + [Fact] public void miss_on_member_name() { var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); - + chain.TryFindVariable("wrong", ValueSource.InputMember, typeof(Guid), out var variable) .ShouldBeFalse(); } - + [Fact] public void for_matching_member_name_on_name_and_type_and_anything_is_the_source() { var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); - + chain.TryFindVariable(nameof(CreateThing.Id), ValueSource.Anything, typeof(Guid), out var variable) .ShouldBeTrue(); - + variable.ShouldBeOfType() .Member.Name.ShouldBe("Id"); } - + [Fact] public void miss_on_unsupported_value_sources() { var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); - + chain.TryFindVariable(nameof(CreateThing.Id), ValueSource.RouteValue, typeof(Guid), out var variable) .ShouldBeFalse(); } + [Fact] + public void header_source_creates_envelope_header_variable() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + chain.TryFindVariable("X-Tenant-Id", ValueSource.Header, typeof(string), out var variable) + .ShouldBeTrue(); + + variable.ShouldNotBeNull(); + variable.Creator.ShouldBeOfType(); + } + + [Fact] + public void header_source_with_typed_guid_value() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + chain.TryFindVariable("X-Stream-Id", ValueSource.Header, typeof(Guid), out var variable) + .ShouldBeTrue(); + + variable.ShouldNotBeNull(); + variable.VariableType.ShouldBe(typeof(Guid)); + variable.Creator.ShouldBeOfType(); + } + + [Fact] + public void header_source_with_typed_int_value() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + chain.TryFindVariable("X-Count", ValueSource.Header, typeof(int), out var variable) + .ShouldBeTrue(); + + variable.ShouldNotBeNull(); + variable.VariableType.ShouldBe(typeof(int)); + } + + [Fact] + public void anything_source_does_not_fall_back_to_header() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + // Header requires explicit ValueSource.Header — Anything does not include it + // because Header always "succeeds" and would swallow all unmatched names + chain.TryFindVariable("X-Tenant", ValueSource.Anything, typeof(string), out var variable) + .ShouldBeFalse(); + } + + [Fact] + public void anything_source_prefers_input_member_over_header() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + // "Name" matches a property on CreateThing, should use InputMember not Header + chain.TryFindVariable("Name", ValueSource.Anything, typeof(string), out var variable) + .ShouldBeTrue(); + + variable.ShouldBeOfType(); + } + + [Fact] + public void claim_source_throws_in_handler_context() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + Should.Throw(() => + chain.TryFindVariable("sub", ValueSource.Claim, typeof(string), out _)) + .Message.ShouldContain("HTTP endpoints"); + } + + [Fact] + public void method_source_discovers_static_method_on_handler_type() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + chain.TryFindVariable("ResolveId", ValueSource.Method, typeof(Guid), out var variable) + .ShouldBeTrue(); + + variable.ShouldNotBeNull(); + variable.VariableType.ShouldBe(typeof(Guid)); + } + + [Fact] + public void method_source_discovers_static_method_on_base_type() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + chain.TryFindVariable("ResolveId", ValueSource.Method, typeof(Guid), out var variable) + .ShouldBeTrue(); + + variable.ShouldNotBeNull(); + variable.VariableType.ShouldBe(typeof(Guid)); + } + + [Fact] + public void method_source_throws_when_method_not_found() + { + var chain = HandlerChain.For(x => x.Handle(null!), new HandlerGraph()); + + Should.Throw(() => + chain.TryFindVariable("NonExistentMethod", ValueSource.Method, typeof(Guid), out _)) + .Message.ShouldContain("NonExistentMethod"); + } +} + +public class WolverineParameterAttribute_convenience_properties +{ + private class TestAttribute : WolverineParameterAttribute + { + public override JasperFx.CodeGeneration.Model.Variable Modify( + Wolverine.Configuration.IChain chain, System.Reflection.ParameterInfo parameter, + JasperFx.IServiceContainer container, JasperFx.CodeGeneration.GenerationRules rules) + { + throw new NotImplementedException(); + } + } + + [Fact] + public void from_header_sets_value_source_and_argument_name() + { + var att = new TestAttribute { FromHeader = "X-Tenant-Id" }; + att.ValueSource.ShouldBe(ValueSource.Header); + att.ArgumentName.ShouldBe("X-Tenant-Id"); + } + + [Fact] + public void from_route_sets_value_source_and_argument_name() + { + var att = new TestAttribute { FromRoute = "orderId" }; + att.ValueSource.ShouldBe(ValueSource.RouteValue); + att.ArgumentName.ShouldBe("orderId"); + } + + [Fact] + public void from_claim_sets_value_source_and_argument_name() + { + var att = new TestAttribute { FromClaim = "sub" }; + att.ValueSource.ShouldBe(ValueSource.Claim); + att.ArgumentName.ShouldBe("sub"); + } + + [Fact] + public void from_method_sets_value_source_and_argument_name() + { + var att = new TestAttribute { FromMethod = "ResolveId" }; + att.ValueSource.ShouldBe(ValueSource.Method); + att.ArgumentName.ShouldBe("ResolveId"); + } } public record CreateThing(Guid Id, string Name, string Color); @@ -69,3 +218,26 @@ public void Handle(CreateThing command) // Nothing } } + +public class HandlerWithStaticMethod +{ + public static Guid ResolveId() => Guid.NewGuid(); + + public void Handle(CreateThing command) + { + // Nothing + } +} + +public class BaseHandlerWithMethod +{ + public static Guid ResolveId() => Guid.NewGuid(); +} + +public class DerivedHandler : BaseHandlerWithMethod +{ + public void Handle(CreateThing command) + { + // Nothing + } +} diff --git a/src/Wolverine/Attributes/ModifyChainAttribute.cs b/src/Wolverine/Attributes/ModifyChainAttribute.cs index 61a545cc7..8d73faec8 100644 --- a/src/Wolverine/Attributes/ModifyChainAttribute.cs +++ b/src/Wolverine/Attributes/ModifyChainAttribute.cs @@ -37,7 +37,22 @@ public enum ValueSource /// /// The value should be sourced by a query string parameter of an HTTP request /// - FromQueryString + FromQueryString, + + /// + /// The value should be sourced by an HTTP request header or an Envelope header in message handlers + /// + Header, + + /// + /// The value should be sourced from a claim on the ClaimsPrincipal. Only supported in HTTP endpoints. + /// + Claim, + + /// + /// The value should be sourced from the return value of a named static method on the handler or endpoint class + /// + Method } #endregion \ No newline at end of file diff --git a/src/Wolverine/Attributes/WolverineParameterAttribute.cs b/src/Wolverine/Attributes/WolverineParameterAttribute.cs index dabdb24e0..922b14e55 100644 --- a/src/Wolverine/Attributes/WolverineParameterAttribute.cs +++ b/src/Wolverine/Attributes/WolverineParameterAttribute.cs @@ -28,13 +28,69 @@ protected WolverineParameterAttribute(string argumentName) } public string ArgumentName { get; set; } = null!; - + /// /// Where should the identity value for resolving this parameter come from? /// Default is a named member on the message type or HTTP request type (if one exists) /// public ValueSource ValueSource { get; set; } = ValueSource.InputMember; + /// + /// Resolve the value from an HTTP request header or an Envelope header in message handlers. + /// Sets ValueSource to Header and ArgumentName to the specified header name. + /// + public string? FromHeader + { + get => ValueSource == ValueSource.Header ? ArgumentName : null; + set + { + ValueSource = ValueSource.Header; + ArgumentName = value!; + } + } + + /// + /// Resolve the value from an HTTP route argument. + /// Sets ValueSource to RouteValue and ArgumentName to the specified route parameter name. + /// + public string? FromRoute + { + get => ValueSource == ValueSource.RouteValue ? ArgumentName : null; + set + { + ValueSource = ValueSource.RouteValue; + ArgumentName = value!; + } + } + + /// + /// Resolve the value from a claim on the ClaimsPrincipal. Only supported in HTTP endpoints. + /// Sets ValueSource to Claim and ArgumentName to the specified claim type. + /// + public string? FromClaim + { + get => ValueSource == ValueSource.Claim ? ArgumentName : null; + set + { + ValueSource = ValueSource.Claim; + ArgumentName = value!; + } + } + + /// + /// Resolve the value from the return value of a named static method on the handler or endpoint class. + /// Sets ValueSource to Method and ArgumentName to the specified method name. + /// + public string? FromMethod + { + get => ValueSource == ValueSource.Method ? ArgumentName : null; + set + { + ValueSource = ValueSource.Method; + ArgumentName = value!; + } + } + /// /// Called by Wolverine during bootstrapping to modify the code generation /// for an HTTP endpoint with the decorated parameter diff --git a/src/Wolverine/Envelope.cs b/src/Wolverine/Envelope.cs index 05bd56d4d..a0b5a7956 100644 --- a/src/Wolverine/Envelope.cs +++ b/src/Wolverine/Envelope.cs @@ -62,6 +62,21 @@ public Envelope(object message) internal set => _headers = value; } + /// + /// Try to read a header value by key without forcing dictionary allocation. + /// Returns true if the header exists and has a non-null value. + /// + public bool TryGetHeader(string key, out string? value) + { + if (_headers != null && _headers.TryGetValue(key, out value)) + { + return value != null; + } + + value = null; + return false; + } + #region sample_envelope_deliver_by_property /// diff --git a/src/Wolverine/Runtime/Handlers/HandlerChain.cs b/src/Wolverine/Runtime/Handlers/HandlerChain.cs index e02d94652..46651c664 100644 --- a/src/Wolverine/Runtime/Handlers/HandlerChain.cs +++ b/src/Wolverine/Runtime/Handlers/HandlerChain.cs @@ -382,6 +382,12 @@ public override void UseForResponse(MethodCall methodCall) public override bool TryFindVariable(string valueName, ValueSource source, Type valueType, out Variable variable) { + if (source == ValueSource.Claim) + { + throw new InvalidOperationException( + "ValueSource.Claim is only supported in HTTP endpoints, not message handlers. Use ValueSource.Header to read from Envelope headers instead."); + } + if (source == ValueSource.InputMember || source == ValueSource.Anything) { var member = (MemberInfo?)MessageType.GetProperties() @@ -396,10 +402,45 @@ public override bool TryFindVariable(string valueName, ValueSource source, Type } } + if (source == ValueSource.Header) + { + var frame = new ReadEnvelopeHeaderFrame(valueType, valueName); + Middleware.Add(frame); + variable = frame.Variable; + return true; + } + + if (source == ValueSource.Method) + { + return tryFindMethodVariable(valueName, valueType, out variable); + } + variable = default!; return false; } + private bool tryFindMethodVariable(string methodName, Type returnType, out Variable variable) + { + var handlerTypes = Handlers.Select(h => h.HandlerType).Distinct(); + foreach (var type in handlerTypes) + { + var method = type + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy) + .FirstOrDefault(m => m.Name.EqualsIgnoreCase(methodName) && m.ReturnType == returnType); + + if (method != null) + { + var call = new MethodCall(type, method); + Middleware.Add(call); + variable = call.ReturnVariable!; + return true; + } + } + + throw new InvalidOperationException( + $"Could not find a public static method '{methodName}' returning {returnType.FullNameInCode()} on handler types: {handlerTypes.Select(t => t.FullNameInCode()).Join(", ")}"); + } + public override Frame[] AddStopConditionIfNull(Variable variable) { var frame = typeof(EntityIsNotNullGuardFrame<>).CloseAndBuildAs(variable, variable.VariableType); diff --git a/src/Wolverine/Runtime/Handlers/ReadEnvelopeHeaderFrame.cs b/src/Wolverine/Runtime/Handlers/ReadEnvelopeHeaderFrame.cs new file mode 100644 index 000000000..388066520 --- /dev/null +++ b/src/Wolverine/Runtime/Handlers/ReadEnvelopeHeaderFrame.cs @@ -0,0 +1,83 @@ +using JasperFx.CodeGeneration; +using JasperFx.CodeGeneration.Frames; +using JasperFx.CodeGeneration.Model; +using JasperFx.Core.Reflection; + +namespace Wolverine.Runtime.Handlers; + +/// +/// Code generation frame that reads a header value from the message Envelope. +/// Supports string and typed values via TryParse. +/// +internal class ReadEnvelopeHeaderFrame : SyncFrame +{ + private readonly string _headerKey; + private readonly Type _valueType; + private readonly bool _isNullable; + private readonly Type _rawType; + + public ReadEnvelopeHeaderFrame(Type valueType, string headerKey) + { + _headerKey = headerKey; + _valueType = valueType; + _isNullable = valueType.IsNullable(); + _rawType = _isNullable ? valueType.GetInnerTypeFromNullable() : valueType; + Variable = new Variable(valueType, $"envelopeHeader_{headerKey.Replace("-", "_")}", this); + } + + public Variable Variable { get; } + + public override void GenerateCode(GeneratedMethod method, ISourceWriter writer) + { + if (_rawType == typeof(string)) + { + writeStringValue(writer); + } + else + { + writeTypedValue(writer); + } + + Next?.GenerateCode(method, writer); + } + + private void writeStringValue(ISourceWriter writer) + { + writer.Write( + $"context.Envelope!.TryGetHeader(\"{_headerKey}\", out var {Variable.Usage}_raw);"); + writer.Write($"var {Variable.Usage} = {Variable.Usage}_raw;"); + } + + private void writeTypedValue(ISourceWriter writer) + { + var typeName = _rawType.FullNameInCode(); + + writer.Write( + $"context.Envelope!.TryGetHeader(\"{_headerKey}\", out var {Variable.Usage}_raw);"); + writer.Write($"{_valueType.FullNameInCode()} {Variable.Usage} = default;"); + + if (_rawType.IsEnum) + { + writer.Write( + $"BLOCK:if ({Variable.Usage}_raw != null && {typeName}.TryParse<{typeName}>({Variable.Usage}_raw, true, out var {Variable.Usage}_parsed))"); + } + else if (_rawType.IsBoolean()) + { + writer.Write( + $"BLOCK:if ({Variable.Usage}_raw != null && {typeName}.TryParse({Variable.Usage}_raw, out var {Variable.Usage}_parsed))"); + } + else + { + writer.Write( + $"BLOCK:if ({Variable.Usage}_raw != null && {typeName}.TryParse({Variable.Usage}_raw, System.Globalization.CultureInfo.InvariantCulture, out var {Variable.Usage}_parsed))"); + } + + writer.Write($"{Variable.Usage} = {Variable.Usage}_parsed;"); + writer.FinishBlock(); + } + + public override IEnumerable FindVariables(IMethodVariables chain) + { + yield break; + } +}