Skip to content

Commit c73fb3e

Browse files
authored
fix: Before(TestSession) ignores HookExecutor (#2751)
1 parent 455eb2f commit c73fb3e

File tree

6 files changed

+257
-530
lines changed

6 files changed

+257
-530
lines changed

TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections.Immutable;
22
using Microsoft.CodeAnalysis;
33
using Microsoft.CodeAnalysis.CSharp.Syntax;
4+
using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
45
using TUnit.Core.SourceGenerator.CodeGenerators.Writers;
56
using TUnit.Core.SourceGenerator.Extensions;
67

@@ -136,6 +137,7 @@ public int GetHashCode(HookMethodMetadata? obj)
136137
var lineNumber = location.GetLineSpan().StartLinePosition.Line + 1;
137138

138139
var order = GetHookOrder(hookAttribute);
140+
var hookExecutor = GetHookExecutorType(methodSymbol);
139141

140142
return new HookMethodMetadata
141143
{
@@ -146,7 +148,8 @@ public int GetHashCode(HookMethodMetadata? obj)
146148
HookKind = hookKind,
147149
HookType = hookType,
148150
Order = order,
149-
Context = context
151+
Context = context,
152+
HookExecutor = hookExecutor
150153
};
151154
}
152155

@@ -237,6 +240,62 @@ private static int GetHookOrder(AttributeData attribute)
237240
return 0;
238241
}
239242

243+
private static string? GetHookExecutorType(IMethodSymbol methodSymbol)
244+
{
245+
var hookExecutorAttribute = methodSymbol.GetAttributes()
246+
.FirstOrDefault(a => a.AttributeClass?.Name == "HookExecutorAttribute" ||
247+
(a.AttributeClass?.IsGenericType == true &&
248+
a.AttributeClass?.ConstructedFrom?.Name == "HookExecutorAttribute"));
249+
250+
if (hookExecutorAttribute == null)
251+
{
252+
return null;
253+
}
254+
255+
// For generic HookExecutorAttribute<T>, get the type argument
256+
if (hookExecutorAttribute.AttributeClass?.IsGenericType == true)
257+
{
258+
var typeArg = hookExecutorAttribute.AttributeClass.TypeArguments.FirstOrDefault();
259+
return typeArg?.GloballyQualified();
260+
}
261+
262+
// For non-generic HookExecutorAttribute(Type type), get the constructor argument
263+
var typeArgument = hookExecutorAttribute.ConstructorArguments.FirstOrDefault();
264+
if (typeArgument.Value is ITypeSymbol typeSymbol)
265+
{
266+
return typeSymbol.GloballyQualified();
267+
}
268+
269+
return null;
270+
}
271+
272+
private static string GetConcreteHookType(string dictionaryName, bool isInstance)
273+
{
274+
if (isInstance)
275+
{
276+
return "InstanceHookMethod";
277+
}
278+
279+
return dictionaryName switch
280+
{
281+
"BeforeClassHooks" => "BeforeClassHookMethod",
282+
"AfterClassHooks" => "AfterClassHookMethod",
283+
"BeforeAssemblyHooks" => "BeforeAssemblyHookMethod",
284+
"AfterAssemblyHooks" => "AfterAssemblyHookMethod",
285+
"BeforeTestSessionHooks" => "BeforeTestSessionHookMethod",
286+
"AfterTestSessionHooks" => "AfterTestSessionHookMethod",
287+
"BeforeTestDiscoveryHooks" => "BeforeTestDiscoveryHookMethod",
288+
"AfterTestDiscoveryHooks" => "AfterTestDiscoveryHookMethod",
289+
"BeforeEveryTestHooks" => "BeforeTestHookMethod",
290+
"AfterEveryTestHooks" => "AfterTestHookMethod",
291+
"BeforeEveryClassHooks" => "BeforeClassHookMethod",
292+
"AfterEveryClassHooks" => "AfterClassHookMethod",
293+
"BeforeEveryAssemblyHooks" => "BeforeAssemblyHookMethod",
294+
"AfterEveryAssemblyHooks" => "AfterAssemblyHookMethod",
295+
_ => throw new ArgumentException($"Unknown dictionary name: {dictionaryName}")
296+
};
297+
}
298+
240299
private static void GenerateHookRegistry(SourceProductionContext context, ImmutableArray<HookMethodMetadata> hooks)
241300
{
242301
try
@@ -663,7 +722,8 @@ private static void GenerateHookDelegate(CodeWriter writer, HookMethodMetadata h
663722

664723
private static void GenerateHookListPopulation(CodeWriter writer, string dictionaryName, string typeDisplay, List<HookMethodMetadata> hooks, bool isInstance)
665724
{
666-
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd(typeof({typeDisplay}), _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{(isInstance ? "InstanceHookMethod" : $"StaticHookMethod<{GetContextType(hooks.First().HookType)}>")}>());");
725+
var hookType = GetConcreteHookType(dictionaryName, isInstance);
726+
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd(typeof({typeDisplay}), _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{hookType}>());");
667727

668728
foreach (var hook in hooks.OrderBy(h => h.Order))
669729
{
@@ -679,7 +739,8 @@ private static void GenerateHookListPopulation(CodeWriter writer, string diction
679739
private static void GenerateAssemblyHookListPopulation(CodeWriter writer, string dictionaryName, string assemblyVarName, List<HookMethodMetadata> hooks)
680740
{
681741
var assemblyVar = assemblyVarName.Replace(".", "_") + "_assembly";
682-
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd({assemblyVar}, _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.StaticHookMethod<AssemblyHookContext>>());");
742+
var hookType = GetConcreteHookType(dictionaryName, false);
743+
writer.AppendLine($"global::TUnit.Core.Sources.{dictionaryName}.GetOrAdd({assemblyVar}, _ => new global::System.Collections.Concurrent.ConcurrentBag<global::TUnit.Core.Hooks.{hookType}>());");
683744

684745
foreach (var hook in hooks.OrderBy(h => h.Order))
685746
{
@@ -722,7 +783,7 @@ private static void GenerateHookObject(CodeWriter writer, HookMethodMetadata hoo
722783
writer.Append("MethodInfo = ");
723784
SourceInformationWriter.GenerateMethodInformation(writer, hook.Context.SemanticModel.Compilation, hook.TypeSymbol, hook.MethodSymbol, null, ',');
724785
writer.AppendLine();
725-
writer.AppendLine("HookExecutor = null!,");
786+
writer.AppendLine($"HookExecutor = {HookExecutorHelper.GetHookExecutor(hook.HookExecutor)},");
726787
writer.AppendLine($"Order = {hook.Order},");
727788
writer.AppendLine($"Body = {delegateKey}_Body" + (isInstance ? "" : ","));
728789

@@ -835,4 +896,5 @@ public class HookMethodMetadata
835896
public required string HookType { get; init; }
836897
public required int Order { get; init; }
837898
public required GeneratorAttributeSyntaxContext Context { get; init; }
899+
public string? HookExecutor { get; init; }
838900
}

TUnit.Core/Sources.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,23 @@ public static class Sources
1515

1616
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.InstanceHookMethod>> BeforeTestHooks = new();
1717
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.InstanceHookMethod>> AfterTestHooks = new();
18-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestContext>> BeforeEveryTestHooks = new();
19-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestContext>> AfterEveryTestHooks = new();
18+
public static readonly ConcurrentBag<Hooks.BeforeTestHookMethod> BeforeEveryTestHooks = new();
19+
public static readonly ConcurrentBag<Hooks.AfterTestHookMethod> AfterEveryTestHooks = new();
2020

21-
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>>> BeforeClassHooks = new();
22-
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>>> AfterClassHooks = new();
23-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>> BeforeEveryClassHooks = new();
24-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<ClassHookContext>> AfterEveryClassHooks = new();
21+
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.BeforeClassHookMethod>> BeforeClassHooks = new();
22+
public static readonly ConcurrentDictionary<Type, ConcurrentBag<Hooks.AfterClassHookMethod>> AfterClassHooks = new();
23+
public static readonly ConcurrentBag<Hooks.BeforeClassHookMethod> BeforeEveryClassHooks = new();
24+
public static readonly ConcurrentBag<Hooks.AfterClassHookMethod> AfterEveryClassHooks = new();
2525

26-
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>>> BeforeAssemblyHooks = new();
27-
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>>> AfterAssemblyHooks = new();
28-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>> BeforeEveryAssemblyHooks = new();
29-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<AssemblyHookContext>> AfterEveryAssemblyHooks = new();
26+
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.BeforeAssemblyHookMethod>> BeforeAssemblyHooks = new();
27+
public static readonly ConcurrentDictionary<Assembly, ConcurrentBag<Hooks.AfterAssemblyHookMethod>> AfterAssemblyHooks = new();
28+
public static readonly ConcurrentBag<Hooks.BeforeAssemblyHookMethod> BeforeEveryAssemblyHooks = new();
29+
public static readonly ConcurrentBag<Hooks.AfterAssemblyHookMethod> AfterEveryAssemblyHooks = new();
3030

31-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestSessionContext>> BeforeTestSessionHooks = [];
32-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestSessionContext>> AfterTestSessionHooks = [];
33-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<BeforeTestDiscoveryContext>> BeforeTestDiscoveryHooks = [];
34-
public static readonly ConcurrentBag<Hooks.StaticHookMethod<TestDiscoveryContext>> AfterTestDiscoveryHooks = [];
31+
public static readonly ConcurrentBag<Hooks.BeforeTestSessionHookMethod> BeforeTestSessionHooks = [];
32+
public static readonly ConcurrentBag<Hooks.AfterTestSessionHookMethod> AfterTestSessionHooks = [];
33+
public static readonly ConcurrentBag<Hooks.BeforeTestDiscoveryHookMethod> BeforeTestDiscoveryHooks = [];
34+
public static readonly ConcurrentBag<Hooks.AfterTestDiscoveryHookMethod> AfterTestDiscoveryHooks = [];
3535

3636
public static readonly ConcurrentQueue<Func<Task>> GlobalInitializers = [];
3737
public static readonly ConcurrentQueue<IPropertySource> PropertySources = [];

0 commit comments

Comments
 (0)