diff --git a/Harmony/Internal/PatchTools.cs b/Harmony/Internal/PatchTools.cs
index 88494d9b..1a2df22d 100644
--- a/Harmony/Internal/PatchTools.cs
+++ b/Harmony/Internal/PatchTools.cs
@@ -84,8 +84,16 @@ internal static MethodBase GetOriginalMethod(this HarmonyMethod attr)
case MethodType.Enumerator:
if (attr.methodName is null)
return null;
- var method = AccessTools.DeclaredMethod(attr.declaringType, attr.methodName, attr.argumentTypes);
- return AccessTools.EnumeratorMoveNext(method);
+ var enumMethod = AccessTools.DeclaredMethod(attr.declaringType, attr.methodName, attr.argumentTypes);
+ return AccessTools.EnumeratorMoveNext(enumMethod);
+
+#if NET45_OR_GREATER
+ case MethodType.Async:
+ if (attr.methodName is null)
+ return null;
+ var asyncMethod = AccessTools.DeclaredMethod(attr.declaringType, attr.methodName, attr.argumentTypes);
+ return AccessTools.AsyncMoveNext(asyncMethod);
+#endif
}
}
catch (AmbiguousMatchException ex)
diff --git a/Harmony/Public/Attributes.cs b/Harmony/Public/Attributes.cs
index 4911f80f..edd077ac 100644
--- a/Harmony/Public/Attributes.cs
+++ b/Harmony/Public/Attributes.cs
@@ -17,8 +17,12 @@ public enum MethodType
Constructor,
/// This is a static constructor
StaticConstructor,
- /// This targets the MoveNext method of the enumerator result
- Enumerator
+ /// This targets the MoveNext method of the enumerator result, that actually contains the method's implementation
+ Enumerator,
+#if NET45_OR_GREATER
+ /// This targets the MoveNext method of the async state machine, that actually contains the method's implementation
+ Async
+#endif
}
/// Specifies the type of argument
diff --git a/Harmony/Tools/AccessTools.cs b/Harmony/Tools/AccessTools.cs
index 4d4ba658..e22b9e89 100644
--- a/Harmony/Tools/AccessTools.cs
+++ b/Harmony/Tools/AccessTools.cs
@@ -8,6 +8,7 @@
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.Serialization;
+using System.Runtime.CompilerServices;
using System.Threading;
namespace HarmonyLib
@@ -466,9 +467,9 @@ public static MethodInfo Method(string typeColonName, Type[] parameters = null,
return Method(info.type, info.name, parameters, generics);
}
- /// Gets the method of an enumerator method
+ /// Gets the method of an enumerator method
/// Enumerator method that creates the enumerator
- /// The internal method of the enumerator or null if no valid enumerator is detected
+ /// The internal method of the enumerator or null if no valid enumerator is detected
public static MethodInfo EnumeratorMoveNext(MethodBase method)
{
if (method is null)
@@ -498,6 +499,37 @@ public static MethodInfo EnumeratorMoveNext(MethodBase method)
return Method(type, nameof(IEnumerator.MoveNext));
}
+#if NET45_OR_GREATER
+ /// Gets the method of an async method's state machine
+ /// Async method that creates the state machine internally
+ /// The internal method of the async state machine or null if no valid async method is detected
+ public static MethodInfo AsyncMoveNext(MethodBase method)
+ {
+ if (method is null)
+ {
+ FileLog.Debug("AccessTools.AsyncMoveNext: method is null");
+ return null;
+ }
+
+ var asyncAttribute = method.GetCustomAttribute();
+ if (asyncAttribute is null)
+ {
+ FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find AsyncStateMachine for {method.FullDescription()}");
+ return null;
+ }
+
+ var asyncStateMachineType = asyncAttribute.StateMachineType;
+ var asyncMethodBody = DeclaredMethod(asyncStateMachineType, nameof(IAsyncStateMachine.MoveNext));
+ if (asyncMethodBody is null)
+ {
+ FileLog.Debug($"AccessTools.AsyncMoveNext: Could not find async method body for {method.FullDescription()}");
+ return null;
+ }
+
+ return asyncMethodBody;
+ }
+#endif
+
/// Gets the names of all method that are declared in a type
/// The declaring class/type
/// A list of method names