diff --git a/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj b/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj
index 1f2708cc24..9e5718edba 100644
--- a/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj
+++ b/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj
@@ -4,6 +4,12 @@
4.14
+
+
+ CodeGenerators\Writers\Hooks
+
+
+
diff --git a/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj b/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj
index a321b8fc20..109e28962a 100644
--- a/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj
+++ b/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj
@@ -4,6 +4,12 @@
4.4
+
+
+ CodeGenerators\Writers\Hooks
+
+
+
diff --git a/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj b/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj
index 7a9abf09c6..c4571db44e 100644
--- a/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj
+++ b/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj
@@ -4,6 +4,12 @@
4.7
+
+
+ CodeGenerators\Writers\Hooks
+
+
+
diff --git a/TUnit.Core.SourceGenerator/Builders/ITestDefinitionBuilder.cs b/TUnit.Core.SourceGenerator/Builders/ITestDefinitionBuilder.cs
deleted file mode 100644
index 17dd95f525..0000000000
--- a/TUnit.Core.SourceGenerator/Builders/ITestDefinitionBuilder.cs
+++ /dev/null
@@ -1,19 +0,0 @@
-using TUnit.Core.SourceGenerator.Models;
-
-namespace TUnit.Core.SourceGenerator.Builders;
-
-///
-/// Interface for building test definitions
-///
-internal interface ITestDefinitionBuilder
-{
- ///
- /// Determines if this builder can handle the given context
- ///
- bool CanBuild(TestMetadataGenerationContext context);
-
- ///
- /// Builds test definitions and writes them to the code writer
- ///
- void BuildTestDefinitions(CodeWriter writer, TestMetadataGenerationContext context);
-}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs b/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs
index daac82b9e0..4ebec87f72 100644
--- a/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs
+++ b/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs
@@ -2,9 +2,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
-using TUnit.Core.SourceGenerator.CodeGenerators.Writers;
using TUnit.Core.SourceGenerator.Extensions;
-using TUnit.Core.SourceGenerator.Utilities;
namespace TUnit.Core.SourceGenerator;
@@ -13,68 +11,6 @@ namespace TUnit.Core.SourceGenerator;
///
internal static class CodeGenerationHelpers
{
- ///
- /// Generates C# code for a ParameterMetadata array from method parameters.
- ///
- public static string GenerateParameterMetadataArray(IMethodSymbol method)
- {
- if (method.Parameters.Length == 0)
- {
- return "global::System.Array.Empty()";
- }
-
- using var writer = new CodeWriter("", includeHeader: false);
- writer.SetIndentLevel(2);
- using (writer.BeginArrayInitializer("new global::TUnit.Core.ParameterMetadata[]"))
- {
- foreach (var param in method.Parameters)
- {
- var parameterIndex = method.Parameters.IndexOf(param);
- var containsTypeParam = ContainsTypeParameter(param.Type);
- var typeForConstructor = containsTypeParam ? "object" : param.Type.GloballyQualified();
-
- using (writer.BeginObjectInitializer($"new global::TUnit.Core.ParameterMetadata(typeof({typeForConstructor}))", ","))
- {
- writer.AppendLine($"Name = \"{param.Name}\",");
- writer.AppendLine($"TypeInfo = {GenerateTypeInfo(param.Type)},");
- writer.AppendLine($"IsNullable = {param.Type.IsNullable().ToString().ToLowerInvariant()},");
- var paramTypesArray = GenerateParameterTypesArray(method);
- if (paramTypesArray == "null")
- {
- writer.AppendLine($"ReflectionInfo = typeof({method.ContainingType.GloballyQualified()}).GetMethods(global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Static).FirstOrDefault(m => m.Name == \"{method.Name}\" && m.GetParameters().Length == {method.Parameters.Length})?.GetParameters()[{parameterIndex}],");
- }
- else
- {
- writer.AppendLine($"ReflectionInfo = typeof({method.ContainingType.GloballyQualified()}).GetMethod(\"{method.Name}\", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Static, null, {paramTypesArray}, null)!.GetParameters()[{parameterIndex}],");
- }
-
- // Generate cached data source attributes for AOT compatibility
- // Include both IDataSourceAttribute and IDataSourceMemberAttribute implementations
- var dataSourceAttributes = param.GetAttributes()
- .Where(attr => attr.AttributeClass != null &&
- attr.AttributeClass.AllInterfaces.Any(i =>
- i.Name == "IDataSourceAttribute" || i.Name == "IDataSourceMemberAttribute"))
- .ToArray();
-
- if (dataSourceAttributes.Length > 0)
- {
- writer.AppendLine($"CachedDataSourceAttributes = new global::System.Attribute[]");
- writer.AppendLine("{");
- writer.SetIndentLevel(3);
- foreach (var attr in dataSourceAttributes)
- {
- var attrCode = GenerateAttributeInstantiation(attr, method.Parameters);
- writer.AppendLine($"{attrCode},");
- }
- writer.SetIndentLevel(2);
- writer.Append("}");
- }
- }
- }
- }
- return writer.ToString().TrimEnd(); // Trim trailing newline for inline use
- }
-
///
/// Generates direct instantiation code for attributes.
///
@@ -289,7 +225,6 @@ private static bool IsParamsArrayArgument(AttributeData attr)
return typeName is "global::TUnit.Core.ArgumentsAttribute" or "global::TUnit.Core.InlineDataAttribute";
}
-
///
/// Determines if an attribute should be excluded from metadata.
///
@@ -320,120 +255,6 @@ public static bool ContainsTypeParameter(ITypeSymbol type)
return false;
}
- ///
- /// Gets a safe type name for use in typeof() expressions.
- /// Returns "object" only for actual type parameters or types containing them.
- /// Returns open generic forms (e.g., List<>) for generic type definitions.
- ///
-
-
- ///
- /// Generates C# code for PropertyMetadata array from class properties.
- ///
- public static string GeneratePropertyMetadataArray(INamedTypeSymbol typeSymbol)
- {
- var properties = typeSymbol.GetMembers()
- .OfType()
- .Where(p => p.DeclaredAccessibility == Accessibility.Public && !p.IsStatic)
- .ToList();
-
- if (properties.Count == 0)
- {
- return "global::System.Array.Empty()";
- }
-
- using var writer = new CodeWriter("", includeHeader: false);
- writer.SetIndentLevel(2);
- using (writer.BeginArrayInitializer("new global::TUnit.Core.PropertyMetadata[]"))
- {
- foreach (var prop in properties)
- {
- using (writer.BeginObjectInitializer("new global::TUnit.Core.PropertyMetadata", ","))
- {
- writer.AppendLine($"Name = \"{prop.Name}\",");
- writer.AppendLine($"Type = typeof({prop.Type.GloballyQualified()}),");
- writer.AppendLine($"ReflectionInfo = typeof({typeSymbol.GloballyQualified()}).GetProperty(\"{prop.Name}\"),");
- writer.AppendLine("IsStatic = false,");
- writer.AppendLine($"IsNullable = {prop.Type.IsNullable().ToString().ToLowerInvariant()},");
- writer.AppendLine($"Getter = obj => ((({typeSymbol.GloballyQualified()})obj).{prop.Name}),");
- writer.AppendLine("ClassMetadata = null!,");
- writer.AppendLine("ContainingTypeMetadata = null!");
- }
- }
- }
- return writer.ToString().TrimEnd(); // Trim trailing newline for inline use
- }
-
- ///
- /// Generates C# code for ConstructorMetadata array from class constructors.
- ///
-
- ///
- /// Generates C# code for class-level data source providers.
- ///
- public static string GenerateClassDataSourceProviders(INamedTypeSymbol typeSymbol)
- {
- var dataSourceAttributes = typeSymbol.GetAttributes()
- .Where(attr => IsDataSourceAttribute(attr))
- .ToList();
-
- if (dataSourceAttributes.Count == 0)
- {
- return "global::System.Array.Empty()";
- }
-
- using var writer = new CodeWriter("", includeHeader: false);
- writer.SetIndentLevel(2);
- using (writer.BeginArrayInitializer("new global::TUnit.Core.TestDataSource[]"))
- {
- foreach (var attr in dataSourceAttributes)
- {
- var providerCode = GenerateDataSourceProvider(attr, typeSymbol);
- if (!string.IsNullOrEmpty(providerCode))
- {
- writer.AppendLine($"{providerCode},");
- }
- }
- }
- return writer.ToString().TrimEnd(); // Trim trailing newline for inline use
- }
-
- ///
- /// Generates C# code for method-level data source providers.
- ///
- public static string GenerateMethodDataSourceProviders(IMethodSymbol methodSymbol)
- {
- var dataSourceAttributes = methodSymbol.GetAttributes()
- .Where(attr => IsDataSourceAttribute(attr))
- .ToList();
-
- // Also check method parameters for data attributes
- foreach (var param in methodSymbol.Parameters)
- {
- dataSourceAttributes.AddRange(param.GetAttributes().Where(IsDataSourceAttribute));
- }
-
- if (dataSourceAttributes.Count == 0)
- {
- return "global::System.Array.Empty()";
- }
-
- using var writer = new CodeWriter("", includeHeader: false);
- writer.SetIndentLevel(2);
- using (writer.BeginArrayInitializer("new global::TUnit.Core.TestDataSource[]"))
- {
- foreach (var attr in dataSourceAttributes)
- {
- var providerCode = GenerateDataSourceProvider(attr, methodSymbol.ContainingType);
- if (!string.IsNullOrEmpty(providerCode))
- {
- writer.AppendLine($"{providerCode},");
- }
- }
- }
- return writer.ToString().TrimEnd(); // Trim trailing newline for inline use
- }
-
///
/// Determines if an attribute is a data source attribute.
///
@@ -448,114 +269,6 @@ private static bool IsDataSourceAttribute(AttributeData attr)
return attr.AttributeClass.AllInterfaces.Any(i => i.GloballyQualified() == "global::TUnit.Core.IDataSourceAttribute");
}
- ///
- /// Generates a data source provider instance based on the attribute type.
- ///
- private static string GenerateDataSourceProvider(AttributeData attr, INamedTypeSymbol containingType)
- {
- var fullName = attr.AttributeClass!.GloballyQualified();
-
- switch (fullName)
- {
- case "TUnit.Core.ArgumentsAttribute":
- case "TUnit.Core.InlineDataAttribute":
- return GenerateInlineDataProvider(attr);
-
- case "TUnit.Core.MethodDataSourceAttribute":
- return GenerateMethodDataSourceProvider(attr, containingType);
-
- case "TUnit.Core.PropertyDataSourceAttribute":
- return GeneratePropertyDataSourceProvider(attr, containingType);
-
- default:
- // For custom IDataSourceAttribute implementations (including ClassDataSourceAttribute)
- return GenerateCustomDataProvider(attr);
- }
- }
-
- private static string GenerateInlineDataProvider(AttributeData attr)
- {
- using var writer = new CodeWriter("", includeHeader: false);
- writer.Append("new global::TUnit.Core.StaticTestDataSource(new object?[][] { new object?[] { ");
-
- var args = attr.ConstructorArguments.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg)).ToList();
- writer.Append(string.Join(", ", args));
- writer.Append(" } })");
- return writer.ToString().Trim();
- }
-
- private static string GenerateMethodDataSourceProvider(AttributeData attr, INamedTypeSymbol containingType)
- {
- if (attr.ConstructorArguments.Length == 0)
- {
- return "";
- }
-
- var methodName = attr.ConstructorArguments[0].Value?.ToString() ?? "";
- var isShared = attr.NamedArguments.FirstOrDefault(na => na.Key == "Shared").Value.Value as bool? ?? false;
-
- // Try to determine if this can be optimized for AOT
- var method = FindDataSourceMethod(methodName, containingType);
- if (method != null && ShouldUseAotOptimizedDataSource(method))
- {
- // Generate code that uses a more AOT-friendly approach
- return GenerateAotOptimizedDataSource(methodName, containingType, isShared);
- }
-
- // Fall back to DynamicTestDataSource for complex cases
- return $"new global::TUnit.Core.DynamicTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ SourceType = typeof({containingType.GloballyQualified()}), SourceMemberName = \"{methodName}\" }}";
- }
-
- private static string GenerateParameterTypesArray(IMethodSymbol method)
- {
- if (method.Parameters.Length == 0)
- {
- return "global::System.Type.EmptyTypes";
- }
-
- if (method.Parameters.Any(p => ContainsTypeParameter(p.Type)))
- {
- return "null";
- }
-
- var parameterTypes = method.Parameters
- .Select(p => $"typeof({p.Type.GloballyQualified()})")
- .ToArray();
-
- return $"new global::System.Type[] {{ {string.Join(", ", parameterTypes)} }}";
- }
-
-
- private static string GeneratePropertyDataSourceProvider(AttributeData attr, INamedTypeSymbol containingType)
- {
- if (attr.ConstructorArguments.Length == 0)
- {
- return "";
- }
-
- var propertyName = attr.ConstructorArguments[0].Value?.ToString() ?? "";
- var isShared = attr.NamedArguments.FirstOrDefault(na => na.Key == "Shared").Value.Value as bool? ?? false;
-
- // Check if we can use AOT-friendly approach for property data sources
- var property = containingType.GetMembers(propertyName).OfType().FirstOrDefault();
- if (property != null && ShouldUseAotOptimizedPropertyDataSource(property))
- {
- return GenerateAotOptimizedPropertyDataSource(propertyName, containingType, isShared);
- }
-
- return $"new global::TUnit.Core.DynamicTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ SourceType = typeof({containingType.GloballyQualified()}), SourceMemberName = \"{propertyName}\" }}";
- }
-
- private static string GenerateCustomDataProvider(AttributeData attr)
- {
- // For custom data attributes that implement IDataSourceAttribute (including AsyncDataSourceGeneratorAttribute),
- // we need to instantiate the attribute and use it directly
- var writer = new CodeWriter();
- AttributeWriter.WriteAttributeWithoutSyntax(writer, attr);
- return writer.ToString();
- }
-
-
///
/// Generates all test-related attributes for the TestMetadata.AttributesByType field as a dictionary.
///
@@ -668,122 +381,4 @@ private static int GetGenericParameterPosition(ITypeParameterSymbol typeParamete
}
return 0;
}
-
- private static string GetGenericTypeDefinitionName(INamedTypeSymbol namedType)
- {
- // Get the unbound generic type (e.g., List`1)
- var unboundType = namedType.ConstructUnboundGenericType();
- return GetAssemblyQualifiedName(unboundType);
- }
-
- private static string GetAssemblyQualifiedName(ITypeSymbol typeSymbol)
- {
- // Build assembly qualified name
- var typeName = typeSymbol.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithoutGlobalPrefix);
-
- if (typeSymbol.ContainingAssembly.Name is "System.Private.CoreLib" or "mscorlib")
- {
- return $"{typeName}, System.Private.CoreLib";
- }
-
- return $"{typeName}, {typeSymbol.ContainingAssembly.Name}";
- }
-
- #region Compile-Time Data Source Resolution
-
- ///
- /// Finds a data source method in the containing type.
- ///
- private static IMethodSymbol? FindDataSourceMethod(string methodName, INamedTypeSymbol containingType)
- {
- return containingType.GetMembers(methodName)
- .OfType()
- .FirstOrDefault(m => m.Parameters.Length == 0);
- }
-
- ///
- /// Determines if a method should use AOT-optimized data source generation.
- ///
- private static bool ShouldUseAotOptimizedDataSource(IMethodSymbol method)
- {
- if (!method.IsStatic)
- {
- return false;
- }
- if (method.Parameters.Length > 0)
- {
- return false;
- }
-
- var returnType = method.ReturnType;
-
- if (returnType is INamedTypeSymbol namedType)
- {
- var typeString = namedType.ToDisplayString();
-
- return typeString.Contains("IEnumerable<") ||
- typeString.Contains("ICollection<") ||
- typeString.Contains("List<") ||
- typeString == "object[][]" ||
- typeString.Contains("object[]");
- }
-
- return false;
- }
-
- ///
- /// Generates AOT-optimized data source code that avoids reflection.
- ///
- private static string GenerateAotOptimizedDataSource(string methodName, INamedTypeSymbol containingType, bool isShared)
- {
- return $"new global::TUnit.Core.AotFriendlyTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ " +
- $"MethodInvoker = () => {containingType.GloballyQualified()}.{methodName}(), " +
- $"SourceType = typeof({containingType.GloballyQualified()}), " +
- $"SourceMemberName = \"{methodName}\" }}";
- }
-
-
- ///
- /// Generates AOT-optimized property data source code that avoids reflection.
- ///
- private static string GenerateAotOptimizedPropertyDataSource(string propertyName, INamedTypeSymbol containingType, bool isShared)
- {
- return $"new global::TUnit.Core.AotFriendlyTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ " +
- $"MethodInvoker = () => new {containingType.GloballyQualified()}().{propertyName}, " +
- $"SourceType = typeof({containingType.GloballyQualified()}), " +
- $"SourceMemberName = \"{propertyName}\" }}";
- }
-
- ///
- /// Determines if a property should use AOT-optimized data source generation.
- ///
- private static bool ShouldUseAotOptimizedPropertyDataSource(IPropertySymbol property)
- {
- if (!property.IsStatic)
- {
- var containingType = property.ContainingType;
- var hasParameterlessConstructor = containingType.Constructors.Any(c => c.Parameters.Length == 0);
- if (!hasParameterlessConstructor)
- {
- return false;
- }
- }
-
- var returnType = property.Type;
-
- if (returnType is INamedTypeSymbol namedType)
- {
- var typeString = namedType.ToDisplayString();
-
- return typeString.Contains("IEnumerable<") ||
- typeString.Contains("ICollection<") ||
- typeString.Contains("List<") ||
- typeString == "object[][]" ||
- typeString.Contains("object[]");
- }
-
- return false;
- }
-
- #endregion
}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs
index 37e93199f1..ec78c7f4d3 100644
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs
+++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs
@@ -1,41 +1,10 @@
using Microsoft.CodeAnalysis;
using TUnit.Core.SourceGenerator.Extensions;
-using TUnit.Core.SourceGenerator.Models;
namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
public static class TupleArgumentHelper
{
- public static List GenerateArgumentAccess(ITypeSymbol parameterType, string argumentsArrayName, int baseIndex)
- {
- var argumentExpressions = new List();
-
- // For method parameters, tuples are NOT supported - the data source
- // must return already unpacked values matching the method signature
- var castExpression = $"global::TUnit.Core.Helpers.CastHelper.Cast<{parameterType.GloballyQualified()}>({argumentsArrayName}[{baseIndex}])";
- argumentExpressions.Add(castExpression);
-
- return argumentExpressions;
- }
-
- /// The types of all constructor parameters
- /// The name of the arguments array (e.g., "args")
- /// A list of argument access expressions for the constructor
- public static List GenerateConstructorArgumentAccess(IList parameterTypes, string argumentsArrayName)
- {
- var argumentExpressions = new List();
-
- // Data sources already provide unwrapped arguments, so we just access by index
- for (var i = 0; i < parameterTypes.Count; i++)
- {
- var parameterType = parameterTypes[i];
- var castExpression = $"global::TUnit.Core.Helpers.CastHelper.Cast<{parameterType.GloballyQualified()}>({argumentsArrayName}[{i}])";
- argumentExpressions.Add(castExpression);
- }
-
- return argumentExpressions;
- }
-
///
/// Generates method invocation arguments.
///
@@ -45,17 +14,17 @@ public static List GenerateConstructorArgumentAccess(IList
public static string GenerateMethodInvocationArguments(IList parameters, string argumentsArrayName)
{
var allArguments = new List();
-
+
for (var i = 0; i < parameters.Count; i++)
{
var parameter = parameters[i];
var castExpression = $"global::TUnit.Core.Helpers.CastHelper.Cast<{parameter.Type.GloballyQualified()}>({argumentsArrayName}[{i}])";
allArguments.Add(castExpression);
}
-
+
return string.Join(", ", allArguments);
}
-
+
///
/// Generates argument access for a method with possible params array, given a specific argument count.
///
@@ -83,10 +52,10 @@ public static List GenerateArgumentAccessWithParams(IList 0 && parameters[parameters.Count - 1].IsParams;
-
+
if (!hasParams)
{
// No params array - just cast each argument
@@ -114,7 +83,7 @@ public static List GenerateArgumentAccessWithParams(IList GenerateArgumentAccessWithParams(IList GenerateArgumentAccessWithParams(IList c.Parameters.Length > 0);
}
-}
\ No newline at end of file
+}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs
index 571c87a76c..fc38c7e5e2 100644
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs
+++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs
@@ -1,8 +1,5 @@
using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.CodeGenerators.Formatting;
-using TUnit.Core.SourceGenerator.Extensions;
namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
@@ -10,66 +7,9 @@ public static class TypedConstantParser
{
private static readonly TypedConstantFormatter _formatter = new();
- public static string GetFullyQualifiedTypeNameFromTypedConstantValue(TypedConstant typedConstant)
- {
- if (typedConstant.Kind == TypedConstantKind.Type)
- {
- var type = (INamedTypeSymbol) typedConstant.Value!;
- return type.GloballyQualified();
- }
-
- if (typedConstant.Kind == TypedConstantKind.Enum)
- {
- return typedConstant.Type!.GloballyQualified();
- }
-
- if (typedConstant.Kind is not TypedConstantKind.Error and not TypedConstantKind.Array)
- {
- return $"global::{typedConstant.Value!.GetType().FullName}";
- }
-
- return typedConstant.Type!.GloballyQualified();
- }
-
public static string GetRawTypedConstantValue(TypedConstant typedConstant, ITypeSymbol? targetType = null)
{
// Use the formatter for consistent handling
return _formatter.FormatForCode(typedConstant, targetType);
}
-
- public static string FormatPrimitive(object? value)
- {
- // Check for special floating-point values first
- var specialFloatValue = SpecialFloatingPointValuesHelper.TryFormatSpecialFloatingPointValue(value);
- if (specialFloatValue != null)
- {
- return specialFloatValue;
- }
-
- switch (value)
- {
- case string s:
- return SymbolDisplay.FormatLiteral(s, quote: true);
- case char c:
- return SymbolDisplay.FormatLiteral(c, quote: true);
- case bool b:
- return b ? "true" : "false";
- case null:
- return "null";
- // Use InvariantCulture for numeric types to ensure consistent formatting
- case double d:
- return d.ToString(System.Globalization.CultureInfo.InvariantCulture) + "d";
- case float f:
- return f.ToString(System.Globalization.CultureInfo.InvariantCulture) + "f";
- case decimal dec:
- return dec.ToString(System.Globalization.CultureInfo.InvariantCulture) + "m";
- default:
- // For other numeric types, use InvariantCulture
- if (value is IFormattable formattable)
- {
- return formattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture);
- }
- return value.ToString() ?? "null";
- }
- }
}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
index ab2323c2a8..3485736116 100644
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
+++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs
@@ -1,5 +1,4 @@
-using System.Collections.Immutable;
-using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
using TUnit.Core.SourceGenerator.Extensions;
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/AssemblyHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/AssemblyHooksWriter.cs
deleted file mode 100644
index 92c54b64e3..0000000000
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/AssemblyHooksWriter.cs
+++ /dev/null
@@ -1,59 +0,0 @@
-using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
-using TUnit.Core.SourceGenerator.Enums;
-using TUnit.Core.SourceGenerator.Models;
-
-namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks;
-
-public static class AssemblyHooksWriter
-{
- public static void Execute(ICodeWriter sourceBuilder, HooksDataModel? model)
- {
- if (model is null)
- {
- return;
- }
-
- if (model.HookLocationType == HookLocationType.Before)
- {
- sourceBuilder.Append("new global::TUnit.Core.Hooks.BeforeAssemblyHookMethod");
- }
- else
- {
- sourceBuilder.Append("new global::TUnit.Core.Hooks.AfterAssemblyHookMethod");
- }
-
- sourceBuilder.Append("{");
- sourceBuilder.Append("MethodInfo = ");
- SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ',');
-
- sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model)})),");
-
- sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},");
- sourceBuilder.Append($"Order = {model.Order},");
- sourceBuilder.Append($"RegistrationIndex = global::TUnit.Core.HookRegistrationIndices.GetNext{(model.HookLocationType == HookLocationType.Before ? "Before" : "After")}{(model.IsEveryHook ? "Every" : "")}AssemblyHookIndex(),");
- sourceBuilder.Append($"""FilePath = @"{model.FilePath}",""");
- sourceBuilder.Append($"LineNumber = {model.LineNumber},");
-
- sourceBuilder.Append("},");
- }
-
- private static string GetArgs(HooksDataModel model)
- {
- List args = [];
-
- foreach (var type in model.ParameterTypes)
- {
- if (type == WellKnownFullyQualifiedClassNames.AssemblyHookContext.WithGlobalPrefix)
- {
- args.Add("context");
- }
-
- if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix)
- {
- args.Add("cancellationToken");
- }
- }
-
- return string.Join(", ", args);
- }
-}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/BaseHookWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/BaseHookWriter.cs
deleted file mode 100644
index f9017e7e8d..0000000000
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/BaseHookWriter.cs
+++ /dev/null
@@ -1,38 +0,0 @@
-using TUnit.Core.SourceGenerator.Enums;
-using TUnit.Core.SourceGenerator.Models;
-
-namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks;
-
-public class BaseHookWriter
-{
- protected static string GetArgs(HooksDataModel model)
- {
- List args = [];
-
- var expectedType = model.HookLevel switch
- {
- "TUnit.Core.HookType.Test" => WellKnownFullyQualifiedClassNames.TestContext,
- "TUnit.Core.HookType.Class" => WellKnownFullyQualifiedClassNames.ClassHookContext,
- "TUnit.Core.HookType.Assembly" => WellKnownFullyQualifiedClassNames.AssemblyHookContext,
- "TUnit.Core.HookType.TestSession" => WellKnownFullyQualifiedClassNames.TestSessionContext,
- "TUnit.Core.HookType.TestDiscovery" when model.HookLocationType == HookLocationType.Before => WellKnownFullyQualifiedClassNames.BeforeTestDiscoveryContext,
- "TUnit.Core.HookType.TestDiscovery" when model.HookLocationType == HookLocationType.After => WellKnownFullyQualifiedClassNames.TestDiscoveryContext,
- _ => throw new ArgumentOutOfRangeException()
- };
-
- foreach (var type in model.ParameterTypes)
- {
- if (type == expectedType.WithGlobalPrefix)
- {
- args.Add("context");
- }
-
- if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix)
- {
- args.Add("cancellationToken");
- }
- }
-
- return string.Join(", ", args);
- }
-}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/ClassHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/ClassHooksWriter.cs
deleted file mode 100644
index e9c4a5ff8c..0000000000
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/ClassHooksWriter.cs
+++ /dev/null
@@ -1,54 +0,0 @@
-using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
-using TUnit.Core.SourceGenerator.Enums;
-using TUnit.Core.SourceGenerator.Models;
-
-namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks;
-
-public static class ClassHooksWriter
-{
- public static void Execute(ICodeWriter sourceBuilder, HooksDataModel model)
- {
- if (model.HookLocationType == HookLocationType.Before)
- {
- sourceBuilder.Append("new global::TUnit.Core.Hooks.BeforeClassHookMethod");
- }
- else
- {
- sourceBuilder.Append("new global::TUnit.Core.Hooks.AfterClassHookMethod");
- }
-
- sourceBuilder.Append("{");
- sourceBuilder.Append("MethodInfo = ");
- SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ',');
-
- sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model)})),");
-
- sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},");
- sourceBuilder.Append($"Order = {model.Order},");
- sourceBuilder.Append($"RegistrationIndex = global::TUnit.Core.HookRegistrationIndices.GetNext{(model.HookLocationType == HookLocationType.Before ? "Before" : "After")}{(model.IsEveryHook ? "Every" : "")}ClassHookIndex(),");
- sourceBuilder.Append($"""FilePath = @"{model.FilePath}",""");
- sourceBuilder.Append($"LineNumber = {model.LineNumber},");
-
- sourceBuilder.Append("},");
- }
-
- private static string GetArgs(HooksDataModel model)
- {
- List args = [];
-
- foreach (var type in model.ParameterTypes)
- {
- if (type == WellKnownFullyQualifiedClassNames.ClassHookContext.WithGlobalPrefix)
- {
- args.Add("context");
- }
-
- if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix)
- {
- args.Add("cancellationToken");
- }
- }
-
- return string.Join(", ", args);
- }
-}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/GlobalTestHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/GlobalTestHooksWriter.cs
deleted file mode 100644
index 365fc2bec3..0000000000
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/GlobalTestHooksWriter.cs
+++ /dev/null
@@ -1,102 +0,0 @@
-using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
-using TUnit.Core.SourceGenerator.Enums;
-using TUnit.Core.SourceGenerator.Models;
-
-namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks;
-
-public static class GlobalTestHooksWriter
-{
- public static void Execute(ICodeWriter sourceBuilder, HooksDataModel model)
- {
- sourceBuilder.Append($"new {GetClassType(model.HookLevel, model.HookLocationType)}");
- sourceBuilder.Append("{");
-
- sourceBuilder.Append("MethodInfo = ");
- SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ',');
-
- sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model, model.HookLocationType)})),");
-
- sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},");
- sourceBuilder.Append($"Order = {model.Order},");
- sourceBuilder.Append($"RegistrationIndex = {GetRegistrationIndexMethod(model.HookLevel, model.HookLocationType, model.IsEveryHook)},");
- sourceBuilder.Append($"""FilePath = @"{model.FilePath}",""");
- sourceBuilder.Append($"LineNumber = {model.LineNumber},");
-
- sourceBuilder.Append("},");
- }
-
- private static string GetClassType(string hookType, HookLocationType hookLocationType)
- {
- if (hookLocationType == HookLocationType.Before)
- {
- return hookType switch
- {
- "TUnit.Core.HookType.Test" => "global::TUnit.Core.Hooks.BeforeTestHookMethod",
- "TUnit.Core.HookType.Class" => "global::TUnit.Core.Hooks.BeforeClassHookMethod",
- "TUnit.Core.HookType.Assembly" => "global::TUnit.Core.Hooks.BeforeAssemblyHookMethod",
- "TUnit.Core.HookType.TestSession" => "global::TUnit.Core.Hooks.BeforeTestSessionHookMethod",
- "TUnit.Core.HookType.TestDiscovery" => "global::TUnit.Core.Hooks.BeforeTestDiscoveryHookMethod",
- _ => throw new ArgumentOutOfRangeException(nameof(hookType), hookType, null)
- };
- }
-
- return hookType switch
- {
- "TUnit.Core.HookType.Test" => "global::TUnit.Core.Hooks.AfterTestHookMethod",
- "TUnit.Core.HookType.Class" => "global::TUnit.Core.Hooks.AfterClassHookMethod",
- "TUnit.Core.HookType.Assembly" => "global::TUnit.Core.Hooks.AfterAssemblyHookMethod",
- "TUnit.Core.HookType.TestSession" => "global::TUnit.Core.Hooks.AfterTestSessionHookMethod",
- "TUnit.Core.HookType.TestDiscovery" => "global::TUnit.Core.Hooks.AfterTestDiscoveryHookMethod",
- _ => throw new ArgumentOutOfRangeException(nameof(hookType), hookType, null)
- };
- }
-
- private static string GetArgs(HooksDataModel model, HookLocationType hookLocationType)
- {
- List args = [];
-
- var expectedType = model.HookLevel switch
- {
- "TUnit.Core.HookType.Test" => WellKnownFullyQualifiedClassNames.TestContext,
- "TUnit.Core.HookType.Class" => WellKnownFullyQualifiedClassNames.ClassHookContext,
- "TUnit.Core.HookType.Assembly" => WellKnownFullyQualifiedClassNames.AssemblyHookContext,
- "TUnit.Core.HookType.TestSession" => WellKnownFullyQualifiedClassNames.TestSessionContext,
- "TUnit.Core.HookType.TestDiscovery" when hookLocationType == HookLocationType.Before => WellKnownFullyQualifiedClassNames.BeforeTestDiscoveryContext,
- "TUnit.Core.HookType.TestDiscovery" when hookLocationType == HookLocationType.After => WellKnownFullyQualifiedClassNames.TestDiscoveryContext,
- _ => throw new ArgumentOutOfRangeException()
- };
-
- foreach (var type in model.ParameterTypes)
- {
- if (type == expectedType.WithGlobalPrefix)
- {
- args.Add("context");
- }
-
- if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix)
- {
- args.Add("cancellationToken");
- }
- }
-
- return string.Join(", ", args);
- }
-
- private static string GetRegistrationIndexMethod(string hookType, HookLocationType hookLocationType, bool isEveryHook)
- {
- var hookTypeSimple = hookType switch
- {
- "TUnit.Core.HookType.Test" => "Test",
- "TUnit.Core.HookType.Class" => "Class",
- "TUnit.Core.HookType.Assembly" => "Assembly",
- "TUnit.Core.HookType.TestSession" => "TestSession",
- "TUnit.Core.HookType.TestDiscovery" => "TestDiscovery",
- _ => throw new ArgumentOutOfRangeException(nameof(hookType), hookType, null)
- };
-
- var prefix = hookLocationType == HookLocationType.Before ? "Before" : "After";
- var everyPart = isEveryHook ? "Every" : "";
-
- return $"global::TUnit.Core.HookRegistrationIndices.GetNext{prefix}{everyPart}{hookTypeSimple}HookIndex()";
- }
-}
diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs
index 80fdd3135a..460410e5ea 100644
--- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs
+++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs
@@ -1,7 +1,6 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using TUnit.Core.SourceGenerator.Enums;
-using TUnit.Core.SourceGenerator.Extensions;
using TUnit.Core.SourceGenerator.Utilities;
namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers;
diff --git a/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs
index 23edbd23f4..68348a2c4b 100644
--- a/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs
+++ b/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs
@@ -52,24 +52,4 @@ public static bool IsTypedDataSourceAttribute(this AttributeData? attributeData)
return typedInterface?.TypeArguments.FirstOrDefault();
}
-
- public static bool IsNonGlobalHook(this AttributeData attributeData, Compilation compilation)
- {
- // Cache type symbols to avoid repeated GetTypeByMetadataName calls
- var beforeAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.BeforeAttribute.WithoutGlobalPrefix);
- var afterAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.AfterAttribute.WithoutGlobalPrefix);
-
- return SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, beforeAttribute)
- || SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, afterAttribute);
- }
-
- public static bool IsGlobalHook(this AttributeData attributeData, Compilation compilation)
- {
- // Cache type symbols to avoid repeated GetTypeByMetadataName calls
- var beforeEveryAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.BeforeEveryAttribute.WithoutGlobalPrefix);
- var afterEveryAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.AfterEveryAttribute.WithoutGlobalPrefix);
-
- return SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, beforeEveryAttribute)
- || SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, afterEveryAttribute);
- }
}
diff --git a/TUnit.Core.SourceGenerator/Extensions/CompilationExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/CompilationExtensions.cs
deleted file mode 100644
index c561359edc..0000000000
--- a/TUnit.Core.SourceGenerator/Extensions/CompilationExtensions.cs
+++ /dev/null
@@ -1,42 +0,0 @@
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-using TUnit.Core.SourceGenerator.Extensions;
-
-namespace TUnit.Analyzers.Extensions;
-
-public static class CompilationExtensions
-{
- public static bool HasImplicitConversionOrGenericParameter(this Compilation compilation, ITypeSymbol? argumentType,
- ITypeSymbol? parameterType)
- {
- if (parameterType?.IsGenericDefinition() == false)
- {
- if (argumentType is null)
- {
- return false;
- }
-
- var conversion = compilation.ClassifyConversion(argumentType, parameterType);
-
- return conversion.IsImplicit || conversion.IsNumeric;
- }
-
- if (parameterType is INamedTypeSymbol { IsGenericType: true, TypeArguments: [{ TypeKind: TypeKind.TypeParameter }] } namedType)
- {
- // `IEnumerable<>`
- if (argumentType is IArrayTypeSymbol { ElementType: { } elementType })
- {
- var specializedSuper = namedType.OriginalDefinition.Construct(elementType);
- return compilation.HasImplicitConversion(argumentType, specializedSuper);
- }
-
- if (argumentType is INamedTypeSymbol { IsGenericType: true, TypeArguments: [{ } genericArgument] })
- {
- var specializedSuper = namedType.OriginalDefinition.Construct(genericArgument);
- return compilation.HasImplicitConversion(argumentType, specializedSuper);
- }
- }
-
- return compilation.HasImplicitConversion(argumentType?.OriginalDefinition, parameterType?.OriginalDefinition);
- }
-}
diff --git a/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs
index c82b1199cd..d28002c83e 100644
--- a/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs
+++ b/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs
@@ -4,7 +4,13 @@ namespace TUnit.Core.SourceGenerator.Extensions;
public static class MethodExtensions
{
- public static AttributeData? GetTestAttribute(this IMethodSymbol methodSymbol)
+ public static AttributeData GetRequiredTestAttribute(this IMethodSymbol methodSymbol)
+ {
+ return GetTestAttribute(methodSymbol) ??
+ throw new ArgumentException($"No test attribute found on {methodSymbol.ContainingType.Name}.{methodSymbol.Name}");
+ }
+
+ private static AttributeData? GetTestAttribute(IMethodSymbol methodSymbol)
{
var attributes = methodSymbol.GetAttributes();
@@ -17,20 +23,4 @@ public static class MethodExtensions
.FirstOrDefault(x => x.AttributeClass?.BaseType?.GloballyQualified()
== WellKnownFullyQualifiedClassNames.BaseTestAttribute.WithGlobalPrefix);
}
-
- public static AttributeData GetRequiredTestAttribute(this IMethodSymbol methodSymbol)
- {
- return GetTestAttribute(methodSymbol) ??
- throw new ArgumentException($"No test attribute found on {methodSymbol.ContainingType.Name}.{methodSymbol.Name}");
- }
-
- public static bool IsTest(this IMethodSymbol methodSymbol)
- {
- return methodSymbol.GetTestAttribute() != null;
- }
-
- public static bool IsHook(this IMethodSymbol methodSymbol, Compilation compilation)
- {
- return methodSymbol.GetAttributes().Any(x => x.IsNonGlobalHook(compilation) || x.IsGlobalHook(compilation));
- }
}
diff --git a/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs
index 971586cc2c..c9d419ff67 100644
--- a/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs
+++ b/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs
@@ -1,5 +1,4 @@
-using System.Collections.Generic;
-using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis;
namespace TUnit.Core.SourceGenerator.Extensions;
@@ -28,7 +27,7 @@ public static bool IsConst(this ISymbol? symbol, out object? constantValue)
constantValue = null;
return false;
}
-
+
///
/// Creates an IEqualityComparer for tuples that uses SymbolEqualityComparer for symbol comparison
///
@@ -36,21 +35,21 @@ public static bool IsConst(this ISymbol? symbol, out object? constantValue)
{
return new TupleSymbolComparer(comparer);
}
-
+
private class TupleSymbolComparer : IEqualityComparer<(INamedTypeSymbol, string)>
{
private readonly IEqualityComparer _symbolComparer;
-
+
public TupleSymbolComparer(IEqualityComparer symbolComparer)
{
_symbolComparer = symbolComparer;
}
-
+
public bool Equals((INamedTypeSymbol, string) x, (INamedTypeSymbol, string) y)
{
return _symbolComparer.Equals(x.Item1, y.Item1) && x.Item2 == y.Item2;
}
-
+
public int GetHashCode((INamedTypeSymbol, string) obj)
{
var hash1 = _symbolComparer.GetHashCode(obj.Item1);
diff --git a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs
index d9c530c89d..ba44c376b4 100644
--- a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs
+++ b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs
@@ -1,33 +1,11 @@
-using System.Collections.Immutable;
-using System.Diagnostics.CodeAnalysis;
+using System.Diagnostics.CodeAnalysis;
using System.Text;
using Microsoft.CodeAnalysis;
-using TUnit.Analyzers.Extensions;
-using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
namespace TUnit.Core.SourceGenerator.Extensions;
public static class TypeExtensions
{
- private static readonly Dictionary ReservedTypeKeywords = new()
- {
- { "System.Boolean", "bool" },
- { "System.Byte", "byte" },
- { "System.SByte", "sbyte" },
- { "System.Char", "char" },
- { "System.Decimal", "decimal" },
- { "System.Double", "double" },
- { "System.Single", "float" },
- { "System.Int32", "int" },
- { "System.UInt32", "uint" },
- { "System.Int64", "long" },
- { "System.UInt64", "ulong" },
- { "System.Int16", "short" },
- { "System.UInt16", "ushort" },
- { "System.Object", "object" },
- { "System.String", "string" }
- };
-
public static string GetMetadataName(this Type type)
{
return $"{type.Namespace}.{type.Name}";
diff --git a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs
index e6675284c4..6d2e909489 100644
--- a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs
+++ b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs
@@ -1,7 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
-using TUnit.Core.SourceGenerator.CodeGenerators.Writers;
using TUnit.Core.SourceGenerator.Extensions;
using TUnit.Core.SourceGenerator.Models;
using TUnit.Core.SourceGenerator.Models.Extracted;
diff --git a/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs b/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs
deleted file mode 100644
index 091a767e0d..0000000000
--- a/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs
+++ /dev/null
@@ -1,359 +0,0 @@
-using System.Collections.Immutable;
-using Microsoft.CodeAnalysis;
-using TUnit.Core.SourceGenerator.CodeGenerators.Helpers;
-using TUnit.Core.SourceGenerator.Extensions;
-
-namespace TUnit.Core.SourceGenerator.Helpers;
-
-///
-/// Provides generic type inference capabilities for test methods with data sources
-///
-internal static class GenericTypeInference
-{
- private static ImmutableArray? TryInferFromTypedDataSources(IMethodSymbol method, ImmutableArray attributes)
- {
- foreach (var attribute in attributes)
- {
- if (attribute.AttributeClass == null)
- {
- continue;
- }
-
- // Check if this is a typed data source (inherits from AsyncDataSourceGeneratorAttribute or DataSourceGeneratorAttribute)
- var baseType = GetTypedDataSourceBase(attribute.AttributeClass);
- if (baseType is { TypeArguments.Length: > 0 })
- {
- // For single type parameter methods, use the first type argument
- if (method.TypeParameters.Length == 1)
- {
- return ImmutableArray.Create(baseType.TypeArguments[0]);
- }
-
- // For multiple type parameters, match by parameter position if possible
- if (baseType.TypeArguments.Length >= method.TypeParameters.Length)
- {
- return baseType.TypeArguments.Take(method.TypeParameters.Length).ToImmutableArray();
- }
- }
- }
-
- return null;
- }
-
- private static INamedTypeSymbol? GetTypedDataSourceBase(INamedTypeSymbol attributeClass)
- {
- var current = attributeClass;
- while (current != null)
- {
- // Check if it's a generic base class
- if (current.IsGenericType)
- {
- var name = current.Name;
- if (name.Contains("DataSourceGeneratorAttribute") ||
- name.Contains("AsyncDataSourceGeneratorAttribute"))
- {
- return current;
- }
- }
-
- current = current.BaseType;
- }
-
- return null;
- }
-
- private static ImmutableArray? TryInferFromTypeInferringAttributes(IMethodSymbol method)
- {
- var inferredTypes = new List();
-
- // Look at each parameter to see if it has attributes that implement IInfersType
- foreach (var parameter in method.Parameters)
- {
- if (parameter.Type is ITypeParameterSymbol typeParam)
- {
- // Check if this parameter has attributes that implement IInfersType
- foreach (var attr in parameter.GetAttributes())
- {
- if (attr.AttributeClass != null)
- {
- // Look for IInfersType in the attribute's interfaces
- var infersTypeInterface = attr.AttributeClass.AllInterfaces
- .FirstOrDefault(i => i.GloballyQualifiedNonGeneric() == "global::TUnit.Core.Interfaces.IInfersType" &&
- i.IsGenericType &&
- i.TypeArguments.Length == 1);
-
- if (infersTypeInterface != null)
- {
- // Get the type argument from IInfersType
- var inferredType = infersTypeInterface.TypeArguments[0];
-
- // Find the index of this type parameter
- var typeParamIndex = -1;
- for (var i = 0; i < method.TypeParameters.Length; i++)
- {
- if (method.TypeParameters[i].Name == typeParam.Name)
- {
- typeParamIndex = i;
- break;
- }
- }
-
- if (typeParamIndex >= 0)
- {
- // Make sure we have enough slots
- while (inferredTypes.Count <= typeParamIndex)
- {
- inferredTypes.Add(null!);
- }
- inferredTypes[typeParamIndex] = inferredType;
- }
- }
- }
- }
- }
- }
-
- // Remove any null entries and check if we have all types
- inferredTypes.RemoveAll(t => t == null);
-
- return inferredTypes.Count == method.TypeParameters.Length
- ? inferredTypes.ToImmutableArray()
- : null;
- }
-
- private static ITypeSymbol? InferTypeFromValue(TypedConstant value)
- {
- if (value.IsNull)
- {
- return null;
- }
-
- // The type of the constant value tells us what T should be
- return value.Type;
- }
-
- private static ImmutableArray? TryInferFromMethodDataSource(IMethodSymbol testMethod, ImmutableArray attributes)
- {
- var methodDataSourceAttributes = attributes
- .Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute")
- .ToList();
-
- if (!methodDataSourceAttributes.Any())
- {
- return null;
- }
-
- foreach (var attr in methodDataSourceAttributes)
- {
- if (attr.ConstructorArguments.Length > 0 && attr.ConstructorArguments[0].Value is string methodName)
- {
- // Find the data source method
- var dataSourceMethod = testMethod.ContainingType.GetMembers(methodName)
- .OfType()
- .FirstOrDefault();
-
- if (dataSourceMethod is { ReturnType: INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: > 0 } namedType })
- // Analyze the return type to extract generic types
- // Handle IEnumerable>
- {
- var funcType = namedType.TypeArguments[0];
- if (funcType is INamedTypeSymbol { Name: "Func", TypeArguments.Length: > 0 } funcNamedType)
- {
- var tupleType = funcNamedType.TypeArguments[0];
- if (tupleType is INamedTypeSymbol { IsTupleType: true } tupleNamedType)
- {
- // Extract types from tuple elements
- var inferredTypes = InferTypesFromTupleElements(testMethod, tupleNamedType);
- if (inferredTypes != null)
- {
- return inferredTypes;
- }
- }
- }
- }
- }
- }
-
- return null;
- }
-
- private static ImmutableArray? InferTypesFromTupleElements(IMethodSymbol testMethod, INamedTypeSymbol tupleType)
- {
- var inferredTypes = new ITypeSymbol[testMethod.TypeParameters.Length];
- var tupleElements = tupleType.TupleElements;
-
- // Map tuple elements to method parameters
- for (var i = 0; i < testMethod.Parameters.Length && i < tupleElements.Length; i++)
- {
- var parameter = testMethod.Parameters[i];
- var tupleElement = tupleElements[i];
-
- if (parameter.Type is ITypeParameterSymbol typeParam)
- {
- // Find the index of this type parameter
- var typeParamIndex = -1;
- for (var j = 0; j < testMethod.TypeParameters.Length; j++)
- {
- if (testMethod.TypeParameters[j].Name == typeParam.Name)
- {
- typeParamIndex = j;
- break;
- }
- }
-
- if (typeParamIndex >= 0)
- {
- // For generic types like IEnumerable, extract T
- var elementType = tupleElement.Type;
- if (elementType is INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: > 0 } namedElementType)
- {
- // For IEnumerable, we want int
- inferredTypes[typeParamIndex] = namedElementType.TypeArguments[0];
- }
- else
- {
- // For direct types
- inferredTypes[typeParamIndex] = elementType;
- }
- }
- }
- else if (parameter.Type is INamedTypeSymbol { IsGenericType: true } paramNamedType)
- {
- // Handle complex generic parameters like Func
- // This is more complex and would need deeper analysis
- var tupleElementType = tupleElement.Type;
- if (tupleElementType is INamedTypeSymbol { Name: "Func" } funcType)
- {
- // Match type arguments between parameter type and tuple element type
- for (var j = 0; j < funcType.TypeArguments.Length && j < paramNamedType.TypeArguments.Length; j++)
- {
- var paramTypeArg = paramNamedType.TypeArguments[j];
- if (paramTypeArg is ITypeParameterSymbol funcTypeParam)
- {
- var typeParamIndex = -1;
- for (var k = 0; k < testMethod.TypeParameters.Length; k++)
- {
- if (testMethod.TypeParameters[k].Name == funcTypeParam.Name)
- {
- typeParamIndex = k;
- break;
- }
- }
-
- if (typeParamIndex >= 0)
- {
- inferredTypes[typeParamIndex] = funcType.TypeArguments[j];
- }
- }
- }
- }
- }
- }
-
- // Check if we have all required types
- if (inferredTypes.All(t => t != null))
- {
- return inferredTypes.ToImmutableArray();
- }
-
- return null;
- }
-
- ///
- /// Gets all unique generic type combinations for a method based on its data sources
- ///
- public static ImmutableArray> GetAllGenericTypeCombinations(
- IMethodSymbol method,
- ImmutableArray attributes)
- {
- var combinations = new List>();
-
- // For Arguments attributes, each one might produce a different type combination
- var argumentsAttributes = attributes
- .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute")
- .ToList();
-
- foreach (var args in argumentsAttributes)
- {
- var types = InferTypesFromSingleArguments(method, args);
- if (types != null && !combinations.Any(c => TypeArraysEqual(c, types.Value)))
- {
- combinations.Add(types.Value);
- }
- }
-
- // For typed data sources, we typically get one type combination
- var typedSourceTypes = TryInferFromTypedDataSources(method, attributes);
- if (typedSourceTypes != null && !combinations.Any(c => TypeArraysEqual(c, typedSourceTypes.Value)))
- {
- combinations.Add(typedSourceTypes.Value);
- }
-
- // For parameter attributes that implement IInfersType
- var inferredTypes = TryInferFromTypeInferringAttributes(method);
- if (inferredTypes != null && !combinations.Any(c => TypeArraysEqual(c, inferredTypes.Value)))
- {
- combinations.Add(inferredTypes.Value);
- }
-
- // For MethodDataSource attributes
- var methodDataSourceTypes = TryInferFromMethodDataSource(method, attributes);
- if (methodDataSourceTypes != null && !combinations.Any(c => TypeArraysEqual(c, methodDataSourceTypes.Value)))
- {
- combinations.Add(methodDataSourceTypes.Value);
- }
-
- return combinations.ToImmutableArray();
- }
-
- private static ImmutableArray? InferTypesFromSingleArguments(IMethodSymbol method, AttributeData args)
- {
- if (!method.IsGenericMethod || args.ConstructorArguments.Length == 0)
- {
- return null;
- }
-
- var inferredTypes = new List();
-
- for (var i = 0; i < method.TypeParameters.Length && i < method.Parameters.Length; i++)
- {
- var parameter = method.Parameters[i];
-
- if (parameter.Type is ITypeParameterSymbol)
- {
- if (i < args.ConstructorArguments.Length)
- {
- var argValue = args.ConstructorArguments[i];
- var inferredType = InferTypeFromValue(argValue);
-
- if (inferredType != null)
- {
- inferredTypes.Add(inferredType);
- }
- }
- }
- }
-
- return inferredTypes.Count == method.TypeParameters.Length
- ? inferredTypes.ToImmutableArray()
- : null;
- }
-
- private static bool TypeArraysEqual(ImmutableArray a, ImmutableArray b)
- {
- if (a.Length != b.Length)
- {
- return false;
- }
-
- for (var i = 0; i < a.Length; i++)
- {
- if (!SymbolEqualityComparer.Default.Equals(a[i], b[i]))
- {
- return false;
- }
- }
-
- return true;
- }
-}
diff --git a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs
index 28fe2326b7..0494c65d1d 100644
--- a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs
+++ b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs
@@ -1,5 +1,3 @@
-using TUnit.Core.SourceGenerator.Models;
-
namespace TUnit.Core.SourceGenerator.Models.Extracted;
///
diff --git a/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj b/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj
index 4340cfc618..f6efb79d95 100644
--- a/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj
+++ b/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj
@@ -24,4 +24,12 @@
+
+
+
+
+
+
+
+
diff --git a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs
index fd2cc9e652..db6b93fda3 100644
--- a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs
+++ b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs
@@ -1,6 +1,5 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
-using TUnit.Core.SourceGenerator.CodeGenerators;
using TUnit.Core.SourceGenerator.Extensions;
namespace TUnit.Core.SourceGenerator.Utilities;