11using System . Collections . Immutable ;
22using Microsoft . CodeAnalysis ;
33using Microsoft . CodeAnalysis . CSharp . Syntax ;
4+ using TUnit . Core . SourceGenerator . CodeGenerators . Helpers ;
45using TUnit . Core . SourceGenerator . CodeGenerators . Writers ;
56using TUnit . Core . SourceGenerator . Extensions ;
67
@@ -136,6 +137,7 @@ public int GetHashCode(HookMethodMetadata? obj)
136137 var lineNumber = location . GetLineSpan ( ) . StartLinePosition . Line + 1 ;
137138
138139 var order = GetHookOrder ( hookAttribute ) ;
140+ var hookExecutor = GetHookExecutorType ( methodSymbol ) ;
139141
140142 return new HookMethodMetadata
141143 {
@@ -146,7 +148,8 @@ public int GetHashCode(HookMethodMetadata? obj)
146148 HookKind = hookKind ,
147149 HookType = hookType ,
148150 Order = order ,
149- Context = context
151+ Context = context ,
152+ HookExecutor = hookExecutor
150153 } ;
151154 }
152155
@@ -237,6 +240,62 @@ private static int GetHookOrder(AttributeData attribute)
237240 return 0 ;
238241 }
239242
243+ private static string ? GetHookExecutorType ( IMethodSymbol methodSymbol )
244+ {
245+ var hookExecutorAttribute = methodSymbol . GetAttributes ( )
246+ . FirstOrDefault ( a => a . AttributeClass ? . Name == "HookExecutorAttribute" ||
247+ ( a . AttributeClass ? . IsGenericType == true &&
248+ a . AttributeClass ? . ConstructedFrom ? . Name == "HookExecutorAttribute" ) ) ;
249+
250+ if ( hookExecutorAttribute == null )
251+ {
252+ return null ;
253+ }
254+
255+ // For generic HookExecutorAttribute<T>, get the type argument
256+ if ( hookExecutorAttribute . AttributeClass ? . IsGenericType == true )
257+ {
258+ var typeArg = hookExecutorAttribute . AttributeClass . TypeArguments . FirstOrDefault ( ) ;
259+ return typeArg ? . GloballyQualified ( ) ;
260+ }
261+
262+ // For non-generic HookExecutorAttribute(Type type), get the constructor argument
263+ var typeArgument = hookExecutorAttribute . ConstructorArguments . FirstOrDefault ( ) ;
264+ if ( typeArgument . Value is ITypeSymbol typeSymbol )
265+ {
266+ return typeSymbol . GloballyQualified ( ) ;
267+ }
268+
269+ return null ;
270+ }
271+
272+ private static string GetConcreteHookType ( string dictionaryName , bool isInstance )
273+ {
274+ if ( isInstance )
275+ {
276+ return "InstanceHookMethod" ;
277+ }
278+
279+ return dictionaryName switch
280+ {
281+ "BeforeClassHooks" => "BeforeClassHookMethod" ,
282+ "AfterClassHooks" => "AfterClassHookMethod" ,
283+ "BeforeAssemblyHooks" => "BeforeAssemblyHookMethod" ,
284+ "AfterAssemblyHooks" => "AfterAssemblyHookMethod" ,
285+ "BeforeTestSessionHooks" => "BeforeTestSessionHookMethod" ,
286+ "AfterTestSessionHooks" => "AfterTestSessionHookMethod" ,
287+ "BeforeTestDiscoveryHooks" => "BeforeTestDiscoveryHookMethod" ,
288+ "AfterTestDiscoveryHooks" => "AfterTestDiscoveryHookMethod" ,
289+ "BeforeEveryTestHooks" => "BeforeTestHookMethod" ,
290+ "AfterEveryTestHooks" => "AfterTestHookMethod" ,
291+ "BeforeEveryClassHooks" => "BeforeClassHookMethod" ,
292+ "AfterEveryClassHooks" => "AfterClassHookMethod" ,
293+ "BeforeEveryAssemblyHooks" => "BeforeAssemblyHookMethod" ,
294+ "AfterEveryAssemblyHooks" => "AfterAssemblyHookMethod" ,
295+ _ => throw new ArgumentException ( $ "Unknown dictionary name: { dictionaryName } ")
296+ } ;
297+ }
298+
240299 private static void GenerateHookRegistry ( SourceProductionContext context , ImmutableArray < HookMethodMetadata > hooks )
241300 {
242301 try
@@ -663,7 +722,8 @@ private static void GenerateHookDelegate(CodeWriter writer, HookMethodMetadata h
663722
664723 private static void GenerateHookListPopulation ( CodeWriter writer , string dictionaryName , string typeDisplay , List < HookMethodMetadata > hooks , bool isInstance )
665724 {
666- writer . AppendLine ( $ "global::TUnit.Core.Sources.{ dictionaryName } .GetOrAdd(typeof({ typeDisplay } ), _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{ ( isInstance ? "InstanceHookMethod" : $ "StaticHookMethod<{ GetContextType ( hooks . First ( ) . HookType ) } >") } >());") ;
725+ var hookType = GetConcreteHookType ( dictionaryName , isInstance ) ;
726+ writer . AppendLine ( $ "global::TUnit.Core.Sources.{ dictionaryName } .GetOrAdd(typeof({ typeDisplay } ), _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{ hookType } >());") ;
667727
668728 foreach ( var hook in hooks . OrderBy ( h => h . Order ) )
669729 {
@@ -679,7 +739,8 @@ private static void GenerateHookListPopulation(CodeWriter writer, string diction
679739 private static void GenerateAssemblyHookListPopulation ( CodeWriter writer , string dictionaryName , string assemblyVarName , List < HookMethodMetadata > hooks )
680740 {
681741 var assemblyVar = assemblyVarName . Replace ( "." , "_" ) + "_assembly" ;
682- writer . AppendLine ( $ "global::TUnit.Core.Sources.{ dictionaryName } .GetOrAdd({ assemblyVar } , _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.StaticHookMethod<AssemblyHookContext>>());") ;
742+ var hookType = GetConcreteHookType ( dictionaryName , false ) ;
743+ writer . AppendLine ( $ "global::TUnit.Core.Sources.{ dictionaryName } .GetOrAdd({ assemblyVar } , _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{ hookType } >());") ;
683744
684745 foreach ( var hook in hooks . OrderBy ( h => h . Order ) )
685746 {
@@ -722,7 +783,7 @@ private static void GenerateHookObject(CodeWriter writer, HookMethodMetadata hoo
722783 writer . Append ( "MethodInfo = " ) ;
723784 SourceInformationWriter . GenerateMethodInformation ( writer , hook . Context . SemanticModel . Compilation , hook . TypeSymbol , hook . MethodSymbol , null , ',' ) ;
724785 writer . AppendLine ( ) ;
725- writer . AppendLine ( "HookExecutor = null! ," ) ;
786+ writer . AppendLine ( $ "HookExecutor = { HookExecutorHelper . GetHookExecutor ( hook . HookExecutor ) } ,") ;
726787 writer . AppendLine ( $ "Order = { hook . Order } ,") ;
727788 writer . AppendLine ( $ "Body = { delegateKey } _Body" + ( isInstance ? "" : "," ) ) ;
728789
@@ -835,4 +896,5 @@ public class HookMethodMetadata
835896 public required string HookType { get ; init ; }
836897 public required int Order { get ; init ; }
837898 public required GeneratorAttributeSyntaxContext Context { get ; init ; }
899+ public string ? HookExecutor { get ; init ; }
838900}
0 commit comments