diff --git a/src/Plugins/RpcServer/RpcServer.cs b/src/Plugins/RpcServer/RpcServer.cs index f770149eaf..527d253027 100644 --- a/src/Plugins/RpcServer/RpcServer.cs +++ b/src/Plugins/RpcServer/RpcServer.cs @@ -21,6 +21,7 @@ using Neo.Plugins.RpcServer.Model; using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.IO.Compression; using System.Linq; @@ -40,7 +41,11 @@ public partial class RpcServer : IDisposable private const string HttpMethodGet = "GET"; private const string HttpMethodPost = "POST"; - private readonly Dictionary _methods = new(); + internal record struct RpcParameter(string Name, Type Type, bool Required, object? DefaultValue); + + private record struct RpcMethod(Delegate Delegate, RpcParameter[] Parameters); + + private readonly Dictionary _methods = new(); private IWebHost? host; private RpcServersSettings settings; @@ -324,9 +329,9 @@ public async Task ProcessAsync(HttpContext context) { (CheckAuth(context) && !settings.DisabledMethods.Contains(method)).True_Or(RpcError.AccessDenied); - if (_methods.TryGetValue(method, out var func)) + if (_methods.TryGetValue(method, out var rpcMethod)) { - response["result"] = ProcessParamsMethod(jsonParameters, func) switch + response["result"] = ProcessParamsMethod(jsonParameters, rpcMethod) switch { JToken result => result, Task task => await task, @@ -366,25 +371,24 @@ public async Task ProcessAsync(HttpContext context) } } - private object? ProcessParamsMethod(JArray arguments, Delegate func) + private object? ProcessParamsMethod(JArray arguments, RpcMethod rpcMethod) { - var parameterInfos = func.Method.GetParameters(); - var args = new object?[parameterInfos.Length]; + var args = new object?[rpcMethod.Parameters.Length]; // If the method has only one parameter of type JArray, invoke the method directly with the arguments - if (parameterInfos.Length == 1 && parameterInfos[0].ParameterType == typeof(JArray)) + if (rpcMethod.Parameters.Length == 1 && rpcMethod.Parameters[0].Type == typeof(JArray)) { - return func.DynamicInvoke(arguments); + return rpcMethod.Delegate.DynamicInvoke(arguments); } - for (var i = 0; i < parameterInfos.Length; i++) + for (var i = 0; i < rpcMethod.Parameters.Length; i++) { - var param = parameterInfos[i]; + var param = rpcMethod.Parameters[i]; if (arguments.Count > i && arguments[i] is not null) // Donot parse null values { try { - args[i] = ParameterConverter.AsParameter(arguments[i]!, param.ParameterType); + args[i] = ParameterConverter.AsParameter(arguments[i]!, param.Type); } catch (Exception e) when (e is not RpcException) { @@ -393,22 +397,13 @@ public async Task ProcessAsync(HttpContext context) } else { - if (param.IsOptional) - { - args[i] = param.DefaultValue; - } - else if (param.ParameterType.IsValueType && Nullable.GetUnderlyingType(param.ParameterType) == null) - { + if (param.Required) throw new ArgumentException($"Required parameter '{param.Name}' is missing"); - } - else - { - args[i] = null; - } + args[i] = param.DefaultValue; } } - return func.DynamicInvoke(args); + return rpcMethod.Delegate.DynamicInvoke(args); } public void RegisterMethods(object handler) @@ -420,11 +415,39 @@ public void RegisterMethods(object handler) if (rpcMethod is null) continue; var name = string.IsNullOrEmpty(rpcMethod.Name) ? method.Name.ToLowerInvariant() : rpcMethod.Name; - var parameters = method.GetParameters().Select(p => p.ParameterType).ToArray(); - var delegateType = Expression.GetDelegateType(parameters.Concat([method.ReturnType]).ToArray()); - - _methods[name] = Delegate.CreateDelegate(delegateType, handler, method); + var delegateParams = method.GetParameters() + .Select(p => p.ParameterType) + .Concat([method.ReturnType]) + .ToArray(); + var delegateType = Expression.GetDelegateType(delegateParams); + + _methods[name] = new RpcMethod( + Delegate.CreateDelegate(delegateType, handler, method), + method.GetParameters().Select(AsRpcParameter).ToArray() + ); } } + + static internal RpcParameter AsRpcParameter(ParameterInfo param) + { + // Required if not optional and not nullable + // For reference types, if parameter has not default value and nullable is disabled, it is optional. + // For value types, if parameter has not default value, it is required. + var required = param.IsOptional ? false : NotNullParameter(param); + return new RpcParameter(param.Name ?? string.Empty, param.ParameterType, required, param.DefaultValue); + } + + static private bool NotNullParameter(ParameterInfo param) + { + if (param.GetCustomAttribute() != null) return true; + if (param.GetCustomAttribute() != null) return true; + + if (param.GetCustomAttribute() != null) return false; + if (param.GetCustomAttribute() != null) return false; + + var context = new NullabilityInfoContext(); + var nullabilityInfo = context.Create(param); + return nullabilityInfo.WriteState == NullabilityState.NotNull; + } } } diff --git a/tests/Neo.Plugins.RpcServer.Tests/UT_RpcServer.cs b/tests/Neo.Plugins.RpcServer.Tests/UT_RpcServer.cs index 30947d34b9..055bbf8f79 100644 --- a/tests/Neo.Plugins.RpcServer.Tests/UT_RpcServer.cs +++ b/tests/Neo.Plugins.RpcServer.Tests/UT_RpcServer.cs @@ -20,6 +20,7 @@ using Neo.Wallets; using Neo.Wallets.NEP6; using System; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Net; @@ -244,8 +245,28 @@ public async Task TestProcessRequest_MixedBatch() private class MockRpcMethods { +#nullable enable [RpcMethod] - internal JToken GetMockMethod() => "mock"; + public JToken GetMockMethod(string info) => $"string {info}"; + + public JToken NullContextMethod(string? info) => $"string-nullable {info}"; + + public JToken IntMethod(int info) => $"int {info}"; + + public JToken IntNullableMethod(int? info) => $"int-nullable {info}"; + + public JToken AllowNullMethod([AllowNull] string info) => $"string-allownull {info}"; +#nullable restore + +#nullable disable + public JToken NullableMethod(string info) => $"string-nullable {info}"; + + public JToken OptionalMethod(string info = "default") => $"string-default {info}"; + + public JToken NotNullMethod([NotNull] string info) => $"string-notnull {info}"; + + public JToken DisallowNullMethod([DisallowNull] string info) => $"string-disallownull {info}"; +#nullable restore } [TestMethod] @@ -256,7 +277,7 @@ public async Task TestRegisterMethods() // Request ProcessAsync with a valid request var context = new DefaultHttpContext(); var body = """ - {"jsonrpc": "2.0", "method": "getmockmethod", "params": [], "id": 1 } + {"jsonrpc": "2.0", "method": "getmockmethod", "params": ["test"], "id": 1 } """; context.Request.Method = "POST"; context.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(body)); @@ -276,10 +297,69 @@ public async Task TestRegisterMethods() // Parse the JSON response and check the result var responseJson = JToken.Parse(output); Assert.IsNotNull(responseJson["result"]); - Assert.AreEqual("mock", responseJson["result"].AsString()); + Assert.AreEqual("string test", responseJson["result"].AsString()); Assert.AreEqual(200, context.Response.StatusCode); } + [TestMethod] + public void TestNullableParameter() + { + var method = typeof(MockRpcMethods).GetMethod("GetMockMethod"); + var parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsTrue(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("NullableMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsFalse(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("NullContextMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsFalse(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("OptionalMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsFalse(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + Assert.AreEqual("default", parameter.DefaultValue); + + method = typeof(MockRpcMethods).GetMethod("IntMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsTrue(parameter.Required); + Assert.AreEqual(typeof(int), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("IntNullableMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsFalse(parameter.Required); + Assert.AreEqual(typeof(int?), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("NotNullMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsTrue(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("AllowNullMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsFalse(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + + method = typeof(MockRpcMethods).GetMethod("DisallowNullMethod"); + parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]); + Assert.IsTrue(parameter.Required); + Assert.AreEqual(typeof(string), parameter.Type); + Assert.AreEqual("info", parameter.Name); + } + [TestMethod] public void TestRpcServerSettings_Load() {