diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index 6729a470..ffc99553 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -55,7 +55,8 @@ internal abstract partial class MessageFormatterRpcMarshaledContextTracker private static readonly ConcurrentDictionary MarshaledTypes = new(); private static readonly JsonRpcTargetOptions RpcMarshalableInterfaceDefaultTargetOptions = new() { NotifyClientOfEvents = false, DisposeOnDisconnect = true }; private static readonly MethodInfo ReleaseMarshaledObjectMethodInfo = typeof(MessageFormatterRpcMarshaledContextTracker).GetMethod(nameof(ReleaseMarshaledObject), BindingFlags.NonPublic | BindingFlags.Instance)!; - private static readonly ConcurrentDictionary MarshalableOptionalInterfaces = new ConcurrentDictionary(); + private static readonly ConcurrentDictionary MarshalableOptionalInterfaces = new(); + private static readonly ConcurrentDictionary RpcMarshalableAttributeCache = new(); private readonly Dictionary marshaledObjects = new Dictionary(); private readonly JsonRpc jsonRpc; @@ -103,12 +104,13 @@ internal static bool TryGetMarshalOptionsForType( [NotNullWhen(true)] out JsonRpcTargetOptions? targetOptions, [NotNullWhen(true)] out RpcMarshalableAttribute? rpcMarshalableAttribute) { - if (TryGetMarshalOptionsForTypeHelper(typeShape.Type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute)) + bool? helperResult = TryGetMarshalOptionsForTypeHelper(typeShape.Type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute); + if (helperResult is not null) { - return true; + return helperResult.Value; } - if (typeShape.Type.GetCustomAttribute() is RpcMarshalableAttribute marshalableAttribute) + if (TryGetRpcMarshalableAttribute(typeShape.Type, out RpcMarshalableAttribute? marshalableAttribute)) { // Validation requires more trim annotations than our NativeAOT callers can provide. // And besides, analyzers should have called out any issues at compile-time. @@ -133,19 +135,23 @@ internal static bool TryGetMarshalOptionsForType( [NotNullWhen(true)] out JsonRpcTargetOptions? targetOptions, [NotNullWhen(true)] out RpcMarshalableAttribute? rpcMarshalableAttribute) { - if (TryGetMarshalOptionsForTypeHelper(type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute)) + bool? helperResult = TryGetMarshalOptionsForTypeHelper(type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute); + if (helperResult is not null) { - // Because events are not checked by the Nerdbank.MessagePack formatter, - // Remove this check when issues related to https://github.com/eiriktsarpalis/PolyType/issues/226 are resolved in this file. - if (type.GetEvents().Length > 0) + if (helperResult is true) { - throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceHasEvents, type.FullName)); + // Because events are not checked by the Nerdbank.MessagePack formatter, + // Remove this check when issues related to https://github.com/eiriktsarpalis/PolyType/issues/226 are resolved in this file. + if (type.GetEvents().Length > 0) + { + throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceHasEvents, type.FullName)); + } } - return true; + return helperResult.Value; } - if (type.GetCustomAttribute() is RpcMarshalableAttribute marshalableAttribute) + if (TryGetRpcMarshalableAttribute(type, out RpcMarshalableAttribute? marshalableAttribute)) { // Validation requires more trim annotations than our NativeAOT callers can provide. // And besides, analyzers should have called out any issues at compile-time. @@ -195,7 +201,7 @@ internal static RpcMarshalableOptionalInterfaceAttribute[] GetMarshalableOptiona foreach (RpcMarshalableOptionalInterfaceAttribute attribute in attributes) { - if (attribute.OptionalInterface.GetCustomAttribute() is null) + if (!TryGetRpcMarshalableAttribute(attribute.OptionalInterface, out _)) { throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.RpcMarshalableOptionalInterfaceMustBeMarshalable, attribute.OptionalInterface.FullName)); } @@ -423,7 +429,7 @@ internal MarshalToken GetToken( protected abstract RpcTargetMetadata GetRpcTargetMetadata([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] Type interfaceType); - private static bool TryGetMarshalOptionsForTypeHelper( + private static bool? TryGetMarshalOptionsForTypeHelper( Type type, JsonRpcProxyOptions defaultProxyOptions, [NotNullWhen(true)] out JsonRpcProxyOptions? proxyOptions, @@ -435,6 +441,7 @@ private static bool TryGetMarshalOptionsForTypeHelper( rpcMarshalableAttribute = null; if (type.IsInterface is false) { + // Definitely not. return false; } @@ -461,7 +468,8 @@ private static bool TryGetMarshalOptionsForTypeHelper( } } - return false; + // Unknown. Caller may want to do more work. + return null; } /// @@ -525,6 +533,12 @@ private static void ValidateMarshalableInterface( } } + private static bool TryGetRpcMarshalableAttribute(Type type, [NotNullWhen(true)] out RpcMarshalableAttribute? attribute) + { + attribute = RpcMarshalableAttributeCache.GetOrAdd(type, static type => type.GetCustomAttribute()); + return attribute is not null; + } + /// /// Releases memory associated with marshaled objects. ///