diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index 5e62b2da..221fac6c 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -57,7 +57,8 @@ internal partial class MessageFormatterRpcMarshaledContextTracker private static readonly ConcurrentDictionary MarshaledTypes = new(); private static readonly JsonRpcTargetOptions RpcMarshalableInterfaceDefaultTargetOptions = new() { NotifyClientOfEvents = false, DisposeOnDisconnect = true }; private static readonly MethodInfo ReleaseMarshaledObjectMethodInfo = typeof(MessageFormatterRpcMarshaledContextTracker).GetMethod(nameof(ReleaseMarshaledObject), BindingFlags.NonPublic | BindingFlags.Instance)!; - private static readonly ConcurrentDictionary MarshalableOptionalInterfaces = new ConcurrentDictionary(); + private static readonly ConcurrentDictionary MarshalableOptionalInterfaces = new(); + private static readonly ConcurrentDictionary RpcMarshalableAttributeCache = new(); private readonly Dictionary marshaledObjects = new Dictionary(); private readonly JsonRpc jsonRpc; @@ -106,12 +107,13 @@ internal static bool TryGetMarshalOptionsForType( [NotNullWhen(true)] out JsonRpcTargetOptions? targetOptions, [NotNullWhen(true)] out RpcMarshalableAttribute? rpcMarshalableAttribute) { - if (TryGetMarshalOptionsForTypeHelper(typeShape.Type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute)) + bool? helperResult = TryGetMarshalOptionsForTypeHelper(typeShape.Type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute); + if (helperResult is not null) { - return true; + return helperResult.Value; } - if (typeShape.Type.GetCustomAttribute() is RpcMarshalableAttribute marshalableAttribute) + if (TryGetRpcMarshalableAttribute(typeShape.Type, out RpcMarshalableAttribute? marshalableAttribute)) { // Validation requires more trim annotations than our NativeAOT callers can provide. // And besides, analyzers should have called out any issues at compile-time. @@ -137,19 +139,23 @@ internal static bool TryGetMarshalOptionsForType( [NotNullWhen(true)] out JsonRpcTargetOptions? targetOptions, [NotNullWhen(true)] out RpcMarshalableAttribute? rpcMarshalableAttribute) { - if (TryGetMarshalOptionsForTypeHelper(type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute)) + bool? helperResult = TryGetMarshalOptionsForTypeHelper(type, defaultProxyOptions, out proxyOptions, out targetOptions, out rpcMarshalableAttribute); + if (helperResult is not null) { - // Because events are not checked by the Nerdbank.MessagePack formatter, - // Remove this check when issues related to https://github.com/eiriktsarpalis/PolyType/issues/226 are resolved in this file. - if (type.GetEvents().Length > 0) + if (helperResult is true) { - throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceHasEvents, type.FullName)); + // Because events are not checked by the Nerdbank.MessagePack formatter, + // Remove this check when issues related to https://github.com/eiriktsarpalis/PolyType/issues/226 are resolved in this file. + if (type.GetEvents().Length > 0) + { + throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceHasEvents, type.FullName)); + } } - return true; + return helperResult.Value; } - if (type.GetCustomAttribute() is RpcMarshalableAttribute marshalableAttribute) + if (TryGetRpcMarshalableAttribute(type, out RpcMarshalableAttribute? marshalableAttribute)) { // Validation requires more trim annotations than our NativeAOT callers can provide. // And besides, analyzers should have called out any issues at compile-time. @@ -199,7 +205,7 @@ internal static RpcMarshalableOptionalInterfaceAttribute[] GetMarshalableOptiona foreach (RpcMarshalableOptionalInterfaceAttribute attribute in attributes) { - if (attribute.OptionalInterface.GetCustomAttribute() is null) + if (!TryGetRpcMarshalableAttribute(attribute.OptionalInterface, out _)) { throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.RpcMarshalableOptionalInterfaceMustBeMarshalable, attribute.OptionalInterface.FullName)); } @@ -441,7 +447,7 @@ internal MarshalToken GetToken( } } - private static bool TryGetMarshalOptionsForTypeHelper( + private static bool? TryGetMarshalOptionsForTypeHelper( Type type, JsonRpcProxyOptions defaultProxyOptions, [NotNullWhen(true)] out JsonRpcProxyOptions? proxyOptions, @@ -453,6 +459,7 @@ private static bool TryGetMarshalOptionsForTypeHelper( rpcMarshalableAttribute = null; if (type.IsInterface is false) { + // Definitely not. return false; } @@ -479,7 +486,8 @@ private static bool TryGetMarshalOptionsForTypeHelper( } } - return false; + // Unknown. Caller may want to do more work. + return null; } /// @@ -545,6 +553,12 @@ private static void ValidateMarshalableInterface( } #endif + private static bool TryGetRpcMarshalableAttribute(Type type, [NotNullWhen(true)] out RpcMarshalableAttribute? attribute) + { + attribute = RpcMarshalableAttributeCache.GetOrAdd(type, static type => type.GetCustomAttribute()); + return attribute is not null; + } + /// /// Releases memory associated with marshaled objects. ///