diff --git a/TUnit.Core.SourceGenerator/Generators/AotMethodInvocationGenerator.cs b/TUnit.Core.SourceGenerator/Generators/AotMethodInvocationGenerator.cs
deleted file mode 100644
index 76caa1ace6..0000000000
--- a/TUnit.Core.SourceGenerator/Generators/AotMethodInvocationGenerator.cs
+++ /dev/null
@@ -1,717 +0,0 @@
-using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Linq;
-using System.Text;
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
-using TUnit.Core.SourceGenerator.CodeGenerators;
-using TUnit.Core.SourceGenerator.Extensions;
-
-namespace TUnit.Core.SourceGenerator.Generators;
-
-///
-/// Generates AOT-compatible method invocation code to replace MethodInfo.Invoke calls
-///
-[Generator]
-public sealed class AotMethodInvocationGenerator : IIncrementalGenerator
-{
- public void Initialize(IncrementalGeneratorInitializationContext context)
- {
- // Find all method data source attributes and methods that need invocation
- var methodDataSources = context.SyntaxProvider
- .CreateSyntaxProvider(
- predicate: (node, _) => IsMethodDataSourceUsage(node),
- transform: (ctx, _) => ExtractMethodDataSourceInfo(ctx))
- .Where(x => x is not null)
- .Select((x, _) => x!);
-
- // Find all MethodInfo.Invoke usage
- var methodInvocations = context.SyntaxProvider
- .CreateSyntaxProvider(
- predicate: (node, _) => IsMethodInfoInvocation(node),
- transform: (ctx, _) => ExtractMethodInvocationInfo(ctx))
- .Where(x => x is not null)
- .Select((x, _) => x!);
-
- // Combine all method invocation requirements
- var allMethodInfo = methodDataSources
- .Collect()
- .Combine(methodInvocations.Collect());
-
- // Generate the method invocation helpers
- context.RegisterSourceOutput(allMethodInfo, GenerateMethodInvokers);
- }
-
- private static bool IsMethodDataSourceUsage(SyntaxNode node)
- {
- // Look for MethodDataSourceAttribute usage
- if (node is AttributeSyntax attribute)
- {
- var name = attribute.Name.ToString();
- return name.Contains("MethodDataSource") || name.Contains("MethodDataSourceAttribute");
- }
-
- // Look for method declarations that could be data sources
- if (node is MethodDeclarationSyntax method)
- {
- // Check if method returns IEnumerable or similar
- var returnType = method.ReturnType?.ToString();
- if (returnType != null && (
- returnType.Contains("IEnumerable") ||
- returnType.Contains("IAsyncEnumerable") ||
- returnType.Contains("Task 0)
- {
- var firstArg = attribute.ArgumentList.Arguments[0];
- if (firstArg.Expression is LiteralExpressionSyntax literal)
- {
- methodName = literal.Token.ValueText;
- }
- }
-
- if (string.IsNullOrEmpty(methodName))
- {
- return null;
- }
-
- // Find the method in the containing type
- var containingClass = attribute.Ancestors().OfType().FirstOrDefault();
- if (containingClass == null)
- {
- return null;
- }
-
- var classSymbol = semanticModel.GetDeclaredSymbol(containingClass) as INamedTypeSymbol;
- var targetMethod = classSymbol?.GetMembers(methodName!)
- .OfType()
- .FirstOrDefault();
-
- if (targetMethod == null)
- {
- return null;
- }
-
- // Only include publicly accessible methods for AOT compatibility
- if (!IsAccessibleMethod(targetMethod))
- {
- return null;
- }
-
- return new MethodDataSourceInfo
- {
- TargetMethod = targetMethod,
- Location = attribute.GetLocation(),
- Usage = MethodUsage.DataSource
- };
- }
-
- private static MethodDataSourceInfo? ExtractFromMethod(MethodDeclarationSyntax method, SemanticModel semanticModel)
- {
- if (semanticModel.GetDeclaredSymbol(method) is not IMethodSymbol methodSymbol)
- {
- return null;
- }
-
- // Check if this method could be used as a data source
- var returnType = methodSymbol.ReturnType;
- if (!IsDataSourceReturnType(returnType))
- {
- return null;
- }
-
- // Only include publicly accessible methods for AOT compatibility
- if (!IsAccessibleMethod(methodSymbol))
- {
- return null;
- }
-
- return new MethodDataSourceInfo
- {
- TargetMethod = methodSymbol,
- Location = method.GetLocation(),
- Usage = MethodUsage.DataSource
- };
- }
-
- private static IMethodSymbol? ExtractTargetMethod(ExpressionSyntax methodInfoExpression, SemanticModel semanticModel)
- {
- // Try to extract method from common patterns:
- // - typeof(Class).GetMethod("MethodName")
- // - instance.GetType().GetMethod("MethodName")
-
- if (methodInfoExpression is InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax { Name.Identifier.ValueText: "GetMethod" } memberAccess } invocation)
- {
- // Extract method name from GetMethod call
- if (invocation.ArgumentList.Arguments.Count > 0 &&
- invocation.ArgumentList.Arguments[0].Expression is LiteralExpressionSyntax literal)
- {
- var methodName = literal.Token.ValueText;
-
- // Try to resolve the type
- ITypeSymbol? targetType = null;
-
- if (memberAccess.Expression is InvocationExpressionSyntax { Expression: IdentifierNameSyntax { Identifier.ValueText: "typeof" } } typeofInvocation)
- {
- // typeof(Class).GetMethod pattern
- if (typeofInvocation.ArgumentList.Arguments.Count > 0)
- {
- targetType = semanticModel.GetTypeInfo(typeofInvocation.ArgumentList.Arguments[0].Expression).Type;
- }
- }
- else
- {
- // instance.GetType().GetMethod pattern
- var getTypeInvocation = memberAccess.Expression as InvocationExpressionSyntax;
- if (getTypeInvocation?.Expression is MemberAccessExpressionSyntax { Name.Identifier.ValueText: "GetType" } getTypeMember)
- {
- targetType = semanticModel.GetTypeInfo(getTypeMember.Expression).Type;
- }
- }
-
- if (targetType != null)
- {
- var method = targetType.GetMembers(methodName)
- .OfType()
- .FirstOrDefault();
-
- // Only return publicly accessible methods
- if (method != null && IsAccessibleMethod(method))
- {
- return method;
- }
- }
- }
- }
-
- return null;
- }
-
- private static bool IsDataSourceReturnType(ITypeSymbol returnType)
- {
- var typeName = returnType.ToDisplayString();
- return typeName.Contains("IEnumerable") ||
- typeName.Contains("IAsyncEnumerable") ||
- (returnType is INamedTypeSymbol namedType &&
- namedType.AllInterfaces.Any(i => i.Name.Contains("IEnumerable")));
- }
-
- private static void GenerateMethodInvokers(SourceProductionContext context,
- (ImmutableArray dataSources, ImmutableArray invocations) data)
- {
- var (dataSources, invocations) = data;
-
- if (dataSources.IsEmpty && invocations.IsEmpty)
- {
- return;
- }
-
- var writer = new CodeWriter();
- writer.AppendLine("#nullable enable");
- writer.AppendLine();
- writer.AppendLine("using System;");
- writer.AppendLine("using System.Threading.Tasks;");
- writer.AppendLine("using System.Collections.Generic;");
- writer.AppendLine("using System.Collections;");
- writer.AppendLine();
- writer.AppendLine("namespace TUnit.Generated;");
- writer.AppendLine();
-
- GenerateMethodInvokerClass(writer, dataSources, invocations);
-
- context.AddSource("AotMethodInvokers.g.cs", writer.ToString());
- }
-
- private static void GenerateMethodInvokerClass(CodeWriter writer,
- ImmutableArray dataSources,
- ImmutableArray invocations)
- {
- writer.AppendLine("/// ");
- writer.AppendLine("/// AOT-compatible method invocation helpers to replace MethodInfo.Invoke");
- writer.AppendLine("/// ");
- writer.AppendLine("public static class AotMethodInvokers");
- writer.AppendLine("{");
- writer.Indent();
-
- // Generate registry
- GenerateMethodRegistry(writer, dataSources, invocations);
-
- // Generate invocation helper methods
- GenerateInvocationMethods(writer, dataSources, invocations);
-
- // Generate strongly-typed invokers for each method (avoid duplicates by invoker name)
- var allMethods = new HashSet(SymbolEqualityComparer.Default);
- foreach (var ds in dataSources)
- {
- if (!HasUnresolvedTypeParameters(ds.TargetMethod) && IsAccessibleMethod(ds.TargetMethod))
- {
- allMethods.Add(ds.TargetMethod);
- }
- }
- foreach (var inv in invocations)
- {
- if (!HasUnresolvedTypeParameters(inv.TargetMethod) && IsAccessibleMethod(inv.TargetMethod))
- {
- allMethods.Add(inv.TargetMethod);
- }
- }
-
- var processedInvokerNames = new HashSet();
- foreach (var method in allMethods)
- {
- var invokerName = GetInvokerMethodName(method);
- if (processedInvokerNames.Add(invokerName))
- {
- GenerateStronglyTypedInvoker(writer, method);
- }
- }
-
- writer.Unindent();
- writer.AppendLine("}");
- }
-
- private static void GenerateMethodRegistry(CodeWriter writer,
- ImmutableArray dataSources,
- ImmutableArray invocations)
- {
- writer.AppendLine("private static readonly Dictionary>> _methodInvokers = new()");
- writer.AppendLine("{");
- writer.Indent();
-
- var processedMethods = new HashSet();
-
- foreach (var ds in dataSources)
- {
- // Only include methods that will have implementations generated
- if (!HasUnresolvedTypeParameters(ds.TargetMethod) && IsAccessibleMethod(ds.TargetMethod))
- {
- var methodKey = GetMethodKey(ds.TargetMethod);
- if (processedMethods.Add(methodKey))
- {
- var invokerName = GetInvokerMethodName(ds.TargetMethod);
- writer.AppendLine($"[\"{methodKey}\"] = {invokerName},");
- }
- }
- }
-
- foreach (var inv in invocations)
- {
- // Only include methods that will have implementations generated
- if (!HasUnresolvedTypeParameters(inv.TargetMethod) && IsAccessibleMethod(inv.TargetMethod))
- {
- var methodKey = GetMethodKey(inv.TargetMethod);
- if (processedMethods.Add(methodKey))
- {
- var invokerName = GetInvokerMethodName(inv.TargetMethod);
- writer.AppendLine($"[\"{methodKey}\"] = {invokerName},");
- }
- }
- }
-
- writer.Unindent();
- writer.AppendLine("};");
- writer.AppendLine();
- }
-
- private static void GenerateInvocationMethods(CodeWriter writer,
- ImmutableArray dataSources,
- ImmutableArray invocations)
- {
- writer.AppendLine("/// ");
- writer.AppendLine("/// Invokes a method by key (AOT-safe replacement for MethodInfo.Invoke)");
- writer.AppendLine("/// ");
- writer.AppendLine("public static async Task