diff --git a/TUnit.Core.SourceGenerator.Tests/CustomAttributeInheritanceTests.Test.verified.txt b/TUnit.Core.SourceGenerator.Tests/CustomAttributeInheritanceTests.Test.verified.txt
new file mode 100644
index 0000000000..87ba74892b
--- /dev/null
+++ b/TUnit.Core.SourceGenerator.Tests/CustomAttributeInheritanceTests.Test.verified.txt
@@ -0,0 +1,89 @@
+//
+#pragma warning disable
+
+#nullable enable
+namespace TUnit.Generated;
+[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute]
+[global::System.CodeDom.Compiler.GeneratedCode("TUnit", "VERSION_SCRUBBED")]
+internal static class TUnit_TestProject_CustomAttributeInheritanceTests__TestSource
+{
+ private static readonly global::TUnit.Core.ClassMetadata __classMetadata = global::TUnit.Core.ClassMetadata.GetOrAdd("TestsBase`1:global::TUnit.TestProject.CustomAttributeInheritanceTests", new global::TUnit.Core.ClassMetadata
+ {
+ Type = typeof(global::TUnit.TestProject.CustomAttributeInheritanceTests),
+ TypeInfo = new global::TUnit.Core.ConcreteType(typeof(global::TUnit.TestProject.CustomAttributeInheritanceTests)),
+ Name = "CustomAttributeInheritanceTests",
+ Namespace = "TUnit.TestProject",
+ Assembly = global::TUnit.Core.AssemblyMetadata.GetOrAdd("TestsBase`1", "TestsBase`1"),
+ Parameters = global::System.Array.Empty(),
+ Properties = global::System.Array.Empty(),
+ Parent = null
+ });
+ private static readonly global::System.Type __classType = typeof(global::TUnit.TestProject.CustomAttributeInheritanceTests);
+ private static readonly global::TUnit.Core.MethodMetadata __mm_0 = global::TUnit.Core.MethodMetadataFactory.Create("Test", __classType, typeof(void), __classMetadata);
+ private static global::TUnit.TestProject.CustomAttributeInheritanceTests __CreateInstance(global::System.Type[] typeArgs, object?[] args)
+ {
+ return new global::TUnit.TestProject.CustomAttributeInheritanceTests();
+ }
+ private static global::System.Threading.Tasks.ValueTask __Invoke(global::TUnit.TestProject.CustomAttributeInheritanceTests instance, int methodIndex, object?[] args, global::System.Threading.CancellationToken cancellationToken)
+ {
+ switch (methodIndex)
+ {
+ case 0:
+ {
+ try
+ {
+ instance.Test();
+ return default(global::System.Threading.Tasks.ValueTask);
+ }
+ catch (global::System.Exception ex)
+ {
+ return new global::System.Threading.Tasks.ValueTask(global::System.Threading.Tasks.Task.FromException(ex));
+ }
+ }
+ default:
+ throw new global::System.ArgumentOutOfRangeException(nameof(methodIndex));
+ }
+ }
+ private static global::System.Attribute[] __Attributes(int groupIndex)
+ {
+ switch (groupIndex)
+ {
+ case 0:
+ {
+ return
+ [
+ new global::TUnit.Core.TestAttribute(),
+ new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
+ new global::TUnit.TestProject.Attributes.SkipNetFrameworkAttribute("Not supported on .NET Framework")
+ ];
+ }
+ default:
+ throw new global::System.ArgumentOutOfRangeException(nameof(groupIndex));
+ }
+ }
+ public static readonly global::TUnit.Core.TestEntry[] Entries = new global::TUnit.Core.TestEntry[]
+ {
+ new global::TUnit.Core.TestEntry
+ {
+ MethodName = "Test",
+ FullyQualifiedName = "TUnit.TestProject.CustomAttributeInheritanceTests.Test",
+ FilePath = @"",
+ LineNumber = 9,
+ Categories = global::System.Array.Empty(),
+ Properties = global::System.Array.Empty(),
+ HasDataSource = false,
+ RepeatCount = 0,
+ DependsOn = global::System.Array.Empty(),
+ MethodMetadata = __mm_0,
+ CreateInstance = __CreateInstance,
+ InvokeBody = __Invoke,
+ MethodIndex = 0,
+ CreateAttributes = __Attributes,
+ AttributeGroupIndex = 0,
+ },
+ };
+}
+internal static partial class TUnit_TestRegistration
+{
+ static readonly int _r_TUnit_TestProject_CustomAttributeInheritanceTests__TestSource = global::TUnit.Core.SourceRegistrar.RegisterEntries(static () => TUnit_TestProject_CustomAttributeInheritanceTests__TestSource.Entries);
+}
diff --git a/TUnit.Core.SourceGenerator.Tests/CustomAttributeInheritanceTests.cs b/TUnit.Core.SourceGenerator.Tests/CustomAttributeInheritanceTests.cs
new file mode 100644
index 0000000000..c6356e1a04
--- /dev/null
+++ b/TUnit.Core.SourceGenerator.Tests/CustomAttributeInheritanceTests.cs
@@ -0,0 +1,15 @@
+using TUnit.Core.SourceGenerator.Tests.Options;
+
+namespace TUnit.Core.SourceGenerator.Tests;
+
+internal class CustomAttributeInheritanceTests : TestsBase
+{
+ [Test]
+ public Task Test() => RunTest(Path.Combine(Git.RootDirectory.FullName,
+ "TUnit.TestProject",
+ "CustomAttributeInheritanceTests.cs"),
+ new RunTestOptions(),
+ async generatedFiles =>
+ {
+ });
+}
diff --git a/TUnit.Core.SourceGenerator.Tests/MethodDataSourceDrivenWithCancellationTokenTests.Test.verified.txt b/TUnit.Core.SourceGenerator.Tests/MethodDataSourceDrivenWithCancellationTokenTests.Test.verified.txt
index c0ff7c3057..764ded7ade 100644
--- a/TUnit.Core.SourceGenerator.Tests/MethodDataSourceDrivenWithCancellationTokenTests.Test.verified.txt
+++ b/TUnit.Core.SourceGenerator.Tests/MethodDataSourceDrivenWithCancellationTokenTests.Test.verified.txt
@@ -66,8 +66,7 @@ global::TUnit.Core.ParameterMetadataFactory.Create(typeof(global::System.Threadi
return
[
new global::TUnit.Core.TestAttribute(),
- new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Usage", "TUnit0046:Return a `Func` rather than a ``")
+ new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass)
];
}
default:
diff --git a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet10_0.verified.txt b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet10_0.verified.txt
index 964b414140..7b75d88de4 100644
--- a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet10_0.verified.txt
+++ b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet10_0.verified.txt
@@ -223,8 +223,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.Executors.STAThreadExecutorAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
case 1:
@@ -234,8 +233,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.TestAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
default:
diff --git a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet8_0.verified.txt b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet8_0.verified.txt
index 964b414140..7b75d88de4 100644
--- a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet8_0.verified.txt
+++ b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet8_0.verified.txt
@@ -223,8 +223,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.Executors.STAThreadExecutorAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
case 1:
@@ -234,8 +233,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.TestAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
default:
diff --git a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet9_0.verified.txt b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet9_0.verified.txt
index 964b414140..7b75d88de4 100644
--- a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet9_0.verified.txt
+++ b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.DotNet9_0.verified.txt
@@ -223,8 +223,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.Executors.STAThreadExecutorAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
case 1:
@@ -234,8 +233,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.TestAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
default:
diff --git a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.Net4_7.verified.txt b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.Net4_7.verified.txt
index 87807a5a45..008d7c89a1 100644
--- a/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.Net4_7.verified.txt
+++ b/TUnit.Core.SourceGenerator.Tests/STAThreadTests.Test.Net4_7.verified.txt
@@ -223,8 +223,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.Executors.STAThreadExecutorAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
case 1:
@@ -234,8 +233,7 @@ internal static class TUnit_TestProject_STAThreadTests__TestSource
new global::TUnit.Core.TestAttribute(),
new global::TUnit.TestProject.Attributes.EngineTest(global::TUnit.TestProject.Attributes.ExpectedResult.Pass),
new global::TUnit.Core.RunOnAttribute(global::TUnit.Core.Enums.OS.Windows),
- new global::TUnit.Core.RepeatAttribute(100),
- new global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute("Interoperability", "CA1416:Validate platform compatibility")
+ new global::TUnit.Core.RepeatAttribute(100)
];
}
default:
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
index f4cdcc0167..d9cc1d2a7b 100644
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
+++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
@@ -5,9 +5,12 @@
namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers;
-public class AttributeWriter(Compilation compilation, TUnit.Core.SourceGenerator.Helpers.WellKnownTypes wellKnownTypes)
+public class AttributeWriter(Compilation compilation)
{
+ private const string TUnitRootNamespace = "TUnit";
+
private readonly Dictionary _attributeObjectInitializerCache = new();
+ private readonly Dictionary _tunitRelatedCache = new(SymbolEqualityComparer.Default);
public void WriteAttributes(ICodeWriter sourceCodeWriter,
IEnumerable attributeDatas)
@@ -17,22 +20,13 @@ public void WriteAttributes(ICodeWriter sourceCodeWriter,
// Filter out attributes that we can write
foreach (var attributeData in attributeDatas)
{
- // Include attributes with syntax reference (from current compilation)
- // Include attributes without syntax reference (from other assemblies) as long as they have an AttributeClass
- if (attributeData.ApplicationSyntaxReference is null && attributeData.AttributeClass is null)
- {
- continue;
- }
-
- // Skip compiler-internal and assembly-level attributes
- if (ShouldSkipCompilerInternalAttribute(attributeData))
+ if (attributeData.AttributeClass is null)
{
continue;
}
- // Skip framework-specific attributes when targeting older frameworks
- // We determine this by checking if we can compile the attribute
- if (ShouldSkipFrameworkSpecificAttribute(attributeData))
+ // Only include attributes that are, inherit from, or implement a TUnit type
+ if (!IsTUnitRelatedAttribute(attributeData.AttributeClass))
{
continue;
}
@@ -185,80 +179,32 @@ public static void WriteAttributeWithoutSyntax(ICodeWriter sourceCodeWriter, Att
}
- private bool ShouldSkipFrameworkSpecificAttribute(AttributeData attributeData)
+ private bool IsTUnitRelatedAttribute(INamedTypeSymbol attributeClass)
{
- if (attributeData.AttributeClass == null)
+ if (_tunitRelatedCache.TryGetValue(attributeClass, out var cached))
{
- return false;
+ return cached;
}
- // Generic approach: Check if the attribute type is actually available in the target compilation
- // This works by seeing if we can resolve the type from the compilation's references
- var fullyQualifiedName = wellKnownTypes.GetDisplayString(attributeData.AttributeClass);
-
- // Check if this is a system/runtime attribute that might not exist on all frameworks
- if (fullyQualifiedName.StartsWith("System.") || fullyQualifiedName.StartsWith("Microsoft."))
- {
- // Try to get the type from the compilation
- // If it doesn't exist in the compilation's references, we should skip it
- var typeSymbol = wellKnownTypes.TryGet(fullyQualifiedName);
-
- // If the type doesn't exist in the compilation, skip it
- if (typeSymbol == null)
- {
- return true;
- }
+ var result = attributeClass.GetSelfAndBaseTypes().Any(IsInTUnitNamespace)
+ || attributeClass.AllInterfaces.Any(IsInTUnitNamespace);
- // Special handling for attributes that exist but may not be usable
- // For example, nullable attributes exist in the reference assemblies but not at runtime for .NET Framework
- if (IsNullableAttribute(fullyQualifiedName))
- {
- // Check if we're targeting .NET Framework by looking at references
- var isNetFramework = compilation.References.Any(r =>
- r.Display?.Contains("mscorlib") == true &&
- !r.Display.Contains("System.Runtime"));
-
- if (isNetFramework)
- {
- return true; // Skip nullable attributes on .NET Framework
- }
- }
- }
-
- return false;
- }
-
- private static bool IsNullableAttribute(string fullyQualifiedName)
- {
- return fullyQualifiedName.Contains("NullableAttribute") ||
- fullyQualifiedName.Contains("NullableContextAttribute") ||
- fullyQualifiedName.Contains("NullablePublicOnlyAttribute");
+ _tunitRelatedCache[attributeClass] = result;
+ return result;
}
- private bool ShouldSkipCompilerInternalAttribute(AttributeData attributeData)
+ private static bool IsInTUnitNamespace(INamedTypeSymbol type)
{
- if (attributeData.AttributeClass == null)
- {
- return false;
- }
-
- var fullyQualifiedName = wellKnownTypes.GetDisplayString(attributeData.AttributeClass);
-
- // Skip compiler-internal attributes that should never be re-emitted
- // System.Runtime.CompilerServices contains compiler-generated and structural metadata attributes
- if (fullyQualifiedName.StartsWith("System.Runtime.CompilerServices."))
- {
- return true;
- }
+ var ns = type.ContainingNamespace;
+ INamespaceSymbol? outermost = null;
- // Skip debugger attributes (compiler-generated for debugging support)
- if (fullyQualifiedName.StartsWith("System.Diagnostics.Debugger"))
+ while (ns is { IsGlobalNamespace: false })
{
- return true;
+ outermost = ns;
+ ns = ns.ContainingNamespace;
}
- // Skip ParamArrayAttribute (compiler-generated for params keyword)
- return fullyQualifiedName == "System.ParamArrayAttribute";
+ return outermost?.Name == TUnitRootNamespace;
}
}
diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
index ed02d6b4a9..226f8a6119 100644
--- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
+++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
@@ -36,7 +36,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var wellKnownTypes = new WellKnownTypes(c);
return new CompilationContext(
(CSharpCompilation)c,
- new AttributeWriter(c, wellKnownTypes),
+ new AttributeWriter(c),
wellKnownTypes
);
});
diff --git a/TUnit.TestProject/CustomAttributeInheritanceTests.cs b/TUnit.TestProject/CustomAttributeInheritanceTests.cs
new file mode 100644
index 0000000000..8067da6138
--- /dev/null
+++ b/TUnit.TestProject/CustomAttributeInheritanceTests.cs
@@ -0,0 +1,14 @@
+using TUnit.Core;
+using TUnit.TestProject.Attributes;
+
+namespace TUnit.TestProject;
+
+[SkipNetFramework("Not supported on .NET Framework")]
+public class CustomAttributeInheritanceTests
+{
+ [Test]
+ [EngineTest(ExpectedResult.Pass)]
+ public void Test()
+ {
+ }
+}