diff --git a/Harmony/Internal/PatchTools.cs b/Harmony/Internal/PatchTools.cs index 88494d9b..1a2df22d 100644 --- a/Harmony/Internal/PatchTools.cs +++ b/Harmony/Internal/PatchTools.cs @@ -84,8 +84,16 @@ internal static MethodBase GetOriginalMethod(this HarmonyMethod attr) case MethodType.Enumerator: if (attr.methodName is null) return null; - var method = AccessTools.DeclaredMethod(attr.declaringType, attr.methodName, attr.argumentTypes); - return AccessTools.EnumeratorMoveNext(method); + var enumMethod = AccessTools.DeclaredMethod(attr.declaringType, attr.methodName, attr.argumentTypes); + return AccessTools.EnumeratorMoveNext(enumMethod); + +#if NET45_OR_GREATER + case MethodType.Async: + if (attr.methodName is null) + return null; + var asyncMethod = AccessTools.DeclaredMethod(attr.declaringType, attr.methodName, attr.argumentTypes); + return AccessTools.AsyncMoveNext(asyncMethod); +#endif } } catch (AmbiguousMatchException ex) diff --git a/Harmony/Public/Attributes.cs b/Harmony/Public/Attributes.cs index 4911f80f..edd077ac 100644 --- a/Harmony/Public/Attributes.cs +++ b/Harmony/Public/Attributes.cs @@ -17,8 +17,12 @@ public enum MethodType Constructor, /// This is a static constructor StaticConstructor, - /// This targets the MoveNext method of the enumerator result - Enumerator + /// This targets the MoveNext method of the enumerator result, that actually contains the method's implementation + Enumerator, +#if NET45_OR_GREATER + /// This targets the MoveNext method of the async state machine, that actually contains the method's implementation + Async +#endif } /// Specifies the type of argument diff --git a/Harmony/Tools/AccessTools.cs b/Harmony/Tools/AccessTools.cs index 4d4ba658..e22b9e89 100644 --- a/Harmony/Tools/AccessTools.cs +++ b/Harmony/Tools/AccessTools.cs @@ -8,6 +8,7 @@ using System.Reflection; using System.Reflection.Emit; using System.Runtime.Serialization; +using System.Runtime.CompilerServices; using System.Threading; namespace HarmonyLib @@ -466,9 +467,9 @@ public static MethodInfo Method(string typeColonName, Type[] parameters = null, return Method(info.type, info.name, parameters, generics); } - /// Gets the method of an enumerator method + /// Gets the method of an enumerator method /// Enumerator method that creates the enumerator - /// The internal method of the enumerator or null if no valid enumerator is detected + /// The internal method of the enumerator or null if no valid enumerator is detected public static MethodInfo EnumeratorMoveNext(MethodBase method) { if (method is null) @@ -498,6 +499,37 @@ public static MethodInfo EnumeratorMoveNext(MethodBase method) return Method(type, nameof(IEnumerator.MoveNext)); } +#if NET45_OR_GREATER + /// Gets the method of an async method's state machine + /// Async method that creates the state machine internally + /// The internal method of the async state machine or null if no valid async method is detected + public static MethodInfo AsyncMoveNext(MethodBase method) + { + if (method is null) + { + FileLog.Debug("AccessTools.AsyncMoveNext: method is null"); + return null; + } + + var asyncAttribute = method.GetCustomAttribute(); + if (asyncAttribute is null) + { + FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find AsyncStateMachine for {method.FullDescription()}"); + return null; + } + + var asyncStateMachineType = asyncAttribute.StateMachineType; + var asyncMethodBody = DeclaredMethod(asyncStateMachineType, nameof(IAsyncStateMachine.MoveNext)); + if (asyncMethodBody is null) + { + FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find async method body for {method.FullDescription()}"); + return null; + } + + return asyncMethodBody; + } +#endif + /// Gets the names of all method that are declared in a type /// The declaring class/type /// A list of method names