diff --git a/TUnit.Core.SourceGenerator.Tests/RepeatTests.Assembly_Level_Repeat.verified.txt b/TUnit.Core.SourceGenerator.Tests/RepeatTests.Assembly_Level_Repeat.verified.txt new file mode 100644 index 0000000000..d98b2ae554 --- /dev/null +++ b/TUnit.Core.SourceGenerator.Tests/RepeatTests.Assembly_Level_Repeat.verified.txt @@ -0,0 +1,205 @@ +// +#pragma warning disable + +#nullable enable +namespace TUnit.Generated; +[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute] +[global::System.CodeDom.Compiler.GeneratedCode("TUnit", "1.0.0.0")] +internal sealed class TUnit_TestProject_AssemblyRepeatTests_TestWithAssemblyRepeat_TestSource : global::TUnit.Core.Interfaces.SourceGenerator.ITestSource, global::TUnit.Core.Interfaces.SourceGenerator.ITestDescriptorSource +{ + public async global::System.Collections.Generic.IAsyncEnumerable GetTestsAsync(string testSessionId, [global::System.Runtime.CompilerServices.EnumeratorCancellation] global::System.Threading.CancellationToken cancellationToken = default) + { + var metadata = new global::TUnit.Core.TestMetadata + { + TestName = "TestWithAssemblyRepeat", + TestClassType = typeof(global::TUnit.TestProject.AssemblyRepeatTests), + TestMethodName = "TestWithAssemblyRepeat", + Dependencies = global::System.Array.Empty(), + AttributeFactory = static () => + [ + new global::TUnit.Core.TestAttribute(), + new global::TUnit.Core.RepeatAttribute(3) + ], + RepeatCount = 3, + DataSources = global::System.Array.Empty(), + ClassDataSources = global::System.Array.Empty(), + PropertyDataSources = global::System.Array.Empty(), + PropertyInjections = global::System.Array.Empty(), + InheritanceDepth = 0, + FilePath = @"", + LineNumber = 9, + MethodMetadata = new global::TUnit.Core.MethodMetadata + { + Type = typeof(global::TUnit.TestProject.AssemblyRepeatTests), + TypeInfo = new global::TUnit.Core.ConcreteType(typeof(global::TUnit.TestProject.AssemblyRepeatTests)), + Name = "TestWithAssemblyRepeat", + GenericTypeCount = 0, + ReturnType = typeof(void), + ReturnTypeInfo = new global::TUnit.Core.ConcreteType(typeof(void)), + Parameters = global::System.Array.Empty(), + Class = global::TUnit.Core.ClassMetadata.GetOrAdd("TestsBase`1:global::TUnit.TestProject.AssemblyRepeatTests", static () => + { + var classMetadata = new global::TUnit.Core.ClassMetadata + { + Type = typeof(global::TUnit.TestProject.AssemblyRepeatTests), + TypeInfo = new global::TUnit.Core.ConcreteType(typeof(global::TUnit.TestProject.AssemblyRepeatTests)), + Name = "AssemblyRepeatTests", + Namespace = "TUnit.TestProject", + Assembly = global::TUnit.Core.AssemblyMetadata.GetOrAdd("TestsBase`1", static () => new global::TUnit.Core.AssemblyMetadata { Name = "TestsBase`1" }), + Parameters = global::System.Array.Empty(), + Properties = global::System.Array.Empty(), + Parent = null + }; + return classMetadata; + }) + }, + InstanceFactory = (typeArgs, args) => new global::TUnit.TestProject.AssemblyRepeatTests(), + InvokeTypedTest = static (instance, args, cancellationToken) => + { + try + { + instance.TestWithAssemblyRepeat(); + return default(global::System.Threading.Tasks.ValueTask); + } + catch (global::System.Exception ex) + { + return new global::System.Threading.Tasks.ValueTask(global::System.Threading.Tasks.Task.FromException(ex)); + } + }, + }; + metadata.UseRuntimeDataGeneration(testSessionId); + yield return metadata; + yield break; + } + public global::System.Collections.Generic.IEnumerable EnumerateTestDescriptors() + { + yield return new global::TUnit.Core.TestDescriptor + { + TestId = "TUnit.TestProject.AssemblyRepeatTests.TestWithAssemblyRepeat", + ClassName = "AssemblyRepeatTests", + MethodName = "TestWithAssemblyRepeat", + FullyQualifiedName = "TUnit.TestProject.AssemblyRepeatTests.TestWithAssemblyRepeat", + FilePath = @"", + LineNumber = 9, + Categories = global::System.Array.Empty(), + Properties = global::System.Array.Empty(), + HasDataSource = false, + RepeatCount = 3, + DependsOn = global::System.Array.Empty(), + Materializer = GetTestsAsync + }; + } +} +internal static class TUnit_TestProject_AssemblyRepeatTests_TestWithAssemblyRepeat_ModuleInitializer +{ + [global::System.Runtime.CompilerServices.ModuleInitializer] + public static void Initialize() + { + global::TUnit.Core.SourceRegistrar.Register(typeof(global::TUnit.TestProject.AssemblyRepeatTests), new TUnit_TestProject_AssemblyRepeatTests_TestWithAssemblyRepeat_TestSource()); + } +} + + +// ===== FILE SEPARATOR ===== + +// +#pragma warning disable + +#nullable enable +namespace TUnit.Generated; +[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute] +[global::System.CodeDom.Compiler.GeneratedCode("TUnit", "1.0.0.0")] +internal sealed class TUnit_TestProject_AssemblyRepeatTests_TestWithMethodRepeatOverride_TestSource : global::TUnit.Core.Interfaces.SourceGenerator.ITestSource, global::TUnit.Core.Interfaces.SourceGenerator.ITestDescriptorSource +{ + public async global::System.Collections.Generic.IAsyncEnumerable GetTestsAsync(string testSessionId, [global::System.Runtime.CompilerServices.EnumeratorCancellation] global::System.Threading.CancellationToken cancellationToken = default) + { + var metadata = new global::TUnit.Core.TestMetadata + { + TestName = "TestWithMethodRepeatOverride", + TestClassType = typeof(global::TUnit.TestProject.AssemblyRepeatTests), + TestMethodName = "TestWithMethodRepeatOverride", + Dependencies = global::System.Array.Empty(), + AttributeFactory = static () => + [ + new global::TUnit.Core.TestAttribute(), + new global::TUnit.Core.RepeatAttribute(1), + new global::TUnit.Core.RepeatAttribute(3) + ], + RepeatCount = 1, + DataSources = global::System.Array.Empty(), + ClassDataSources = global::System.Array.Empty(), + PropertyDataSources = global::System.Array.Empty(), + PropertyInjections = global::System.Array.Empty(), + InheritanceDepth = 0, + FilePath = @"", + LineNumber = 14, + MethodMetadata = new global::TUnit.Core.MethodMetadata + { + Type = typeof(global::TUnit.TestProject.AssemblyRepeatTests), + TypeInfo = new global::TUnit.Core.ConcreteType(typeof(global::TUnit.TestProject.AssemblyRepeatTests)), + Name = "TestWithMethodRepeatOverride", + GenericTypeCount = 0, + ReturnType = typeof(void), + ReturnTypeInfo = new global::TUnit.Core.ConcreteType(typeof(void)), + Parameters = global::System.Array.Empty(), + Class = global::TUnit.Core.ClassMetadata.GetOrAdd("TestsBase`1:global::TUnit.TestProject.AssemblyRepeatTests", static () => + { + var classMetadata = new global::TUnit.Core.ClassMetadata + { + Type = typeof(global::TUnit.TestProject.AssemblyRepeatTests), + TypeInfo = new global::TUnit.Core.ConcreteType(typeof(global::TUnit.TestProject.AssemblyRepeatTests)), + Name = "AssemblyRepeatTests", + Namespace = "TUnit.TestProject", + Assembly = global::TUnit.Core.AssemblyMetadata.GetOrAdd("TestsBase`1", static () => new global::TUnit.Core.AssemblyMetadata { Name = "TestsBase`1" }), + Parameters = global::System.Array.Empty(), + Properties = global::System.Array.Empty(), + Parent = null + }; + return classMetadata; + }) + }, + InstanceFactory = (typeArgs, args) => new global::TUnit.TestProject.AssemblyRepeatTests(), + InvokeTypedTest = static (instance, args, cancellationToken) => + { + try + { + instance.TestWithMethodRepeatOverride(); + return default(global::System.Threading.Tasks.ValueTask); + } + catch (global::System.Exception ex) + { + return new global::System.Threading.Tasks.ValueTask(global::System.Threading.Tasks.Task.FromException(ex)); + } + }, + }; + metadata.UseRuntimeDataGeneration(testSessionId); + yield return metadata; + yield break; + } + public global::System.Collections.Generic.IEnumerable EnumerateTestDescriptors() + { + yield return new global::TUnit.Core.TestDescriptor + { + TestId = "TUnit.TestProject.AssemblyRepeatTests.TestWithMethodRepeatOverride", + ClassName = "AssemblyRepeatTests", + MethodName = "TestWithMethodRepeatOverride", + FullyQualifiedName = "TUnit.TestProject.AssemblyRepeatTests.TestWithMethodRepeatOverride", + FilePath = @"", + LineNumber = 14, + Categories = global::System.Array.Empty(), + Properties = global::System.Array.Empty(), + HasDataSource = false, + RepeatCount = 1, + DependsOn = global::System.Array.Empty(), + Materializer = GetTestsAsync + }; + } +} +internal static class TUnit_TestProject_AssemblyRepeatTests_TestWithMethodRepeatOverride_ModuleInitializer +{ + [global::System.Runtime.CompilerServices.ModuleInitializer] + public static void Initialize() + { + global::TUnit.Core.SourceRegistrar.Register(typeof(global::TUnit.TestProject.AssemblyRepeatTests), new TUnit_TestProject_AssemblyRepeatTests_TestWithMethodRepeatOverride_TestSource()); + } +} diff --git a/TUnit.Core.SourceGenerator.Tests/RepeatTests.cs b/TUnit.Core.SourceGenerator.Tests/RepeatTests.cs index 09c47a0103..5ad35b8822 100644 --- a/TUnit.Core.SourceGenerator.Tests/RepeatTests.cs +++ b/TUnit.Core.SourceGenerator.Tests/RepeatTests.cs @@ -10,4 +10,47 @@ public Task Test() => RunTest(Path.Combine(Git.RootDirectory.FullName, async generatedFiles => { }); + + [Test] + public async Task Assembly_Level_Repeat() + { + var source = """ + using TUnit.Core; + + [assembly: Repeat(3)] + + namespace TUnit.TestProject; + + public class AssemblyRepeatTests + { + [Test] + public void TestWithAssemblyRepeat() + { + } + + [Test] + [Repeat(1)] + public void TestWithMethodRepeatOverride() + { + } + } + """; + + var tempFile = Path.GetTempFileName() + ".cs"; + await File.WriteAllTextAsync(tempFile, source); + + try + { + await TestMetadataGenerator.RunTest(tempFile, async generatedFiles => + { + }); + } + finally + { + if (File.Exists(tempFile)) + { + File.Delete(tempFile); + } + } + } } diff --git a/TUnit.Core.SourceGenerator/Analyzers/TestMethodAnalyzer.cs b/TUnit.Core.SourceGenerator/Analyzers/TestMethodAnalyzer.cs index 197b51a3e1..3bf4f646f4 100644 --- a/TUnit.Core.SourceGenerator/Analyzers/TestMethodAnalyzer.cs +++ b/TUnit.Core.SourceGenerator/Analyzers/TestMethodAnalyzer.cs @@ -116,7 +116,8 @@ private static (bool isSkipped, string? skipReason) ExtractSkipInfo(IMethodSymbo private static int ExtractRepeatCount(IMethodSymbol methodSymbol) { var repeatAttribute = methodSymbol.GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute"); + .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute" && + a.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core"); if (repeatAttribute is { ConstructorArguments.Length: > 0 }) { diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 354c645cd4..c8d0d8022f 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -2102,6 +2102,7 @@ private static int ExtractRepeatCount(TestMethodMetadata testMethod) foreach (var attr in testMethod.MethodAttributes) { if (attr.AttributeClass?.Name == "RepeatAttribute" && + attr.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core" && attr.ConstructorArguments.Length > 0 && attr.ConstructorArguments[0].Value is int count) { @@ -2113,6 +2114,19 @@ private static int ExtractRepeatCount(TestMethodMetadata testMethod) foreach (var attr in testMethod.TypeSymbol.GetAttributes()) { if (attr.AttributeClass?.Name == "RepeatAttribute" && + attr.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core" && + attr.ConstructorArguments.Length > 0 && + attr.ConstructorArguments[0].Value is int count) + { + return count; + } + } + + // Check assembly attributes + foreach (var attr in testMethod.TypeSymbol.ContainingAssembly.GetAttributes()) + { + if (attr.AttributeClass?.Name == "RepeatAttribute" && + attr.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core" && attr.ConstructorArguments.Length > 0 && attr.ConstructorArguments[0].Value is int count) { @@ -5138,7 +5152,8 @@ private static void GenerateConcreteTestMetadataForNonGeneric( { // Check method-level RepeatAttribute first var repeatAttribute = methodSymbol.GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute"); + .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute" && + a.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core"); if (repeatAttribute?.ConstructorArguments.Length > 0 && repeatAttribute.ConstructorArguments[0].Value is int methodCount) @@ -5148,7 +5163,8 @@ private static void GenerateConcreteTestMetadataForNonGeneric( // Check class-level RepeatAttribute (can be inherited) var classRepeatAttr = typeSymbol.GetAttributesIncludingBaseTypes() - .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute"); + .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute" && + a.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core"); if (classRepeatAttr?.ConstructorArguments.Length > 0 && classRepeatAttr.ConstructorArguments[0].Value is int classCount) @@ -5156,6 +5172,17 @@ private static void GenerateConcreteTestMetadataForNonGeneric( return classCount; } + // Check assembly-level RepeatAttribute + var assemblyRepeatAttr = typeSymbol.ContainingAssembly.GetAttributes() + .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute" && + a.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core"); + + if (assemblyRepeatAttr?.ConstructorArguments.Length > 0 + && assemblyRepeatAttr.ConstructorArguments[0].Value is int assemblyCount) + { + return assemblyCount; + } + // No repeat attribute found return null; } diff --git a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs index cd5e8aacf7..57ec9605e5 100644 --- a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs +++ b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs @@ -1062,7 +1062,8 @@ private static Task BuildTestMetadata( GenericMethodTypeArguments = testMethod.IsGenericMethodDefinition ? null : testMethod.GetGenericArguments(), AttributeFactory = () => ReflectionAttributeExtractor.GetAllAttributes(testClass, testMethod), RepeatCount = testMethod.GetCustomAttribute()?.Times - ?? testClass.GetCustomAttribute()?.Times, + ?? testClass.GetCustomAttribute()?.Times + ?? testClass.Assembly.GetCustomAttribute()?.Times, PropertyInjections = PropertySourceRegistry.DiscoverInjectableProperties(testClass), InheritanceDepth = inheritanceDepth });