diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 6cb4b98438..afac19def2 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -233,9 +233,6 @@ private static void GenerateTestMethodSource(SourceProductionContext context, Te return; } - // Get compilation from semantic model instead of parameter - var compilation = testMethod.Context.Value.SemanticModel.Compilation; - var writer = new CodeWriter(); GenerateFileHeader(writer); GenerateTestMetadata(writer, testMethod); @@ -274,10 +271,7 @@ private static void GenerateFileHeader(CodeWriter writer) private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata testMethod) { - var compilation = testMethod.Context!.Value.SemanticModel.Compilation; - var className = testMethod.TypeSymbol.GloballyQualified(); - var methodName = testMethod.MethodSymbol.Name; // Generate unique class name using same pattern as filename (without .g.cs extension) var uniqueClassName = FileNameHelper.GetDeterministicFileNameForMethod(testMethod.TypeSymbol, testMethod.MethodSymbol) @@ -319,7 +313,7 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t var hasMethodDataSourceForGenericType = testMethod is { IsGenericType: true, IsGenericMethod: false } && testMethod.MethodAttributes .Any(a => a.AttributeClass?.Name == "MethodDataSourceAttribute" && - InferClassTypesFromMethodDataSource(compilation, testMethod, a) != null); + InferClassTypesFromMethodDataSource(testMethod, a) != null); // Check for class-level data sources that could help resolve generic type arguments var hasClassDataSources = testMethod.IsGenericType && testMethod.TypeSymbol.GetAttributesIncludingBaseTypes() @@ -327,7 +321,7 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t if (hasTypedDataSource || hasGenerateGenericTest || testMethod.IsGenericMethod || hasClassArguments || hasTypedDataSourceForGenericType || hasMethodArgumentsForGenericType || hasMethodDataSourceForGenericType || hasClassDataSources) { - GenerateGenericTestWithConcreteTypes(writer, testMethod, className, uniqueClassName); + GenerateGenericTestWithConcreteTypes(writer, testMethod, className); } else { @@ -356,139 +350,6 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t GenerateModuleInitializer(writer, testMethod, uniqueClassName); } - private static void GenerateSpecificGenericInstantiation( - CodeWriter writer, - TestMethodMetadata testMethod, - string className, - string combinationGuid, - ImmutableArray typeArguments) - { - var compilation = testMethod.Context!.Value.SemanticModel.Compilation; - var methodName = testMethod.MethodSymbol.Name; - var typeArgsString = string.Join(", ", typeArguments.Select(t => t.GloballyQualified())); - var instantiatedMethodName = $"{methodName}<{typeArgsString}>"; - - var concreteTestMethod = new TestMethodMetadata - { - MethodSymbol = testMethod.MethodSymbol, - TypeSymbol = testMethod.TypeSymbol, - FilePath = testMethod.FilePath, - LineNumber = testMethod.LineNumber, - TestAttribute = testMethod.TestAttribute, - Context = testMethod.Context, - MethodSyntax = testMethod.MethodSyntax, - IsGenericType = testMethod.IsGenericType, - IsGenericMethod = false, // We're creating a concrete instantiation - MethodAttributes = testMethod.MethodAttributes - }; - - writer.AppendLine($"// Generated instantiation for {instantiatedMethodName}"); - writer.AppendLine("{"); - writer.Indent(); - - writer.AppendLine($"var metadata = new global::TUnit.Core.TestMetadata<{className}>"); - writer.AppendLine("{"); - writer.Indent(); - - writer.AppendLine($"TestName = \"{instantiatedMethodName}\","); - writer.AppendLine($"TestClassType = {GenerateTypeReference(testMethod.TypeSymbol, testMethod.IsGenericType)},"); - writer.AppendLine($"TestMethodName = \"{methodName}\","); - writer.AppendLine($"GenericMethodTypeArguments = new global::System.Type[] {{ {string.Join(", ", typeArguments.Select(t => $"typeof({t.GloballyQualified()})"))}}},"); - - GenerateMetadata(writer, concreteTestMethod); - - if (testMethod.IsGenericType) - { - GenerateGenericTypeInfo(writer, testMethod.TypeSymbol); - } - - GenerateAotFriendlyInvokers(writer, testMethod, className, typeArguments); - - writer.AppendLine($"FilePath = @\"{(testMethod.FilePath ?? "").Replace("\\", "\\\\")}\","); - writer.AppendLine($"LineNumber = {testMethod.LineNumber},"); - - writer.Unindent(); - writer.AppendLine("};"); - - writer.AppendLine("metadata.TestSessionId = testSessionId;"); - writer.AppendLine("yield return metadata;"); - - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine(); - } - - private static void GenerateAotFriendlyInvokers( - CodeWriter writer, - TestMethodMetadata testMethod, - string className, - ImmutableArray typeArguments) - { - var methodName = testMethod.MethodSymbol.Name; - var typeArgsString = string.Join(", ", typeArguments.Select(t => t.GloballyQualified())); - var hasCancellationToken = testMethod.MethodSymbol.Parameters.Any(p => - p.Type.Name == "CancellationToken" && - p.Type.ContainingNamespace?.ToString() == "System.Threading"); - - writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); - writer.AppendLine("{"); - writer.Indent(); - - writer.AppendLine($"return new {className}();"); - - writer.Unindent(); - writer.AppendLine("},"); - - writer.AppendLine("InvokeTypedTest = static (instance, args, cancellationToken) =>"); - writer.AppendLine("{"); - writer.Indent(); - - // Wrap entire lambda body in try-catch to handle synchronous exceptions - writer.AppendLine("try"); - writer.AppendLine("{"); - writer.Indent(); - - // Generate direct method call with specific types (no MakeGenericMethod) - writer.AppendLine($"var typedInstance = ({className})instance;"); - - writer.AppendLine("var methodArgs = new object?[args.Length" + (hasCancellationToken ? " + 1" : "") + "];"); - writer.AppendLine("global::System.Array.Copy(args, methodArgs, args.Length);"); - - if (hasCancellationToken) - { - writer.AppendLine("methodArgs[args.Length] = cancellationToken;"); - } - - var parameterCasts = new List(); - for (var i = 0; i < testMethod.MethodSymbol.Parameters.Length; i++) - { - var param = testMethod.MethodSymbol.Parameters[i]; - if (param.Type.Name == "CancellationToken") - { - parameterCasts.Add("cancellationToken"); - } - else - { - var paramType = ReplaceTypeParametersWithConcreteTypes(param.Type, testMethod.MethodSymbol.TypeParameters, typeArguments); - parameterCasts.Add($"({paramType.GloballyQualified()})methodArgs[{i}]!"); - } - } - - writer.AppendLine($"return global::TUnit.Core.AsyncConvert.Convert(() => typedInstance.{methodName}<{typeArgsString}>({string.Join(", ", parameterCasts)}));"); - - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("catch (global::System.Exception ex)"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("return new global::System.Threading.Tasks.ValueTask(global::System.Threading.Tasks.Task.FromException(ex));"); - writer.Unindent(); - writer.AppendLine("}"); - - writer.Unindent(); - writer.AppendLine("},"); - } - private static ITypeSymbol ReplaceTypeParametersWithConcreteTypes( ITypeSymbol type, ImmutableArray typeParameters, @@ -585,8 +446,7 @@ private static void GenerateMetadata(CodeWriter writer, TestMethodMetadata testM var compilation = testMethod.Context!.Value.SemanticModel.Compilation; var methodSymbol = testMethod.MethodSymbol; - - GenerateDependencies(writer, compilation, methodSymbol); + GenerateDependencies(writer, methodSymbol); writer.AppendLine("AttributeFactory = static () =>"); writer.AppendLine("["); @@ -631,8 +491,7 @@ private static void GenerateMetadataForConcreteInstantiation(CodeWriter writer, var compilation = testMethod.Context!.Value.SemanticModel.Compilation; var methodSymbol = testMethod.MethodSymbol; - - GenerateDependencies(writer, compilation, methodSymbol); + GenerateDependencies(writer, methodSymbol); writer.AppendLine("AttributeFactory = static () =>"); writer.AppendLine("["); @@ -674,10 +533,8 @@ private static void GenerateMetadataForConcreteInstantiation(CodeWriter writer, // Method metadata writer.Append("MethodMetadata = "); SourceInformationWriter.GenerateMethodInformation(writer, compilation, testMethod.TypeSymbol, testMethod.MethodSymbol, null, ','); - } - private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata testMethod) { var compilation = testMethod.Context!.Value.SemanticModel.Compilation; @@ -2073,7 +1930,6 @@ private static void GenerateConcreteTestInvoker(CodeWriter writer, TestMethodMet writer.AppendLine("},"); } - private static void GenerateEnumerateTestDescriptors(CodeWriter writer, TestMethodMetadata testMethod, string className) { var methodName = testMethod.MethodSymbol.Name; @@ -2398,23 +2254,6 @@ private static void GenerateModuleInitializer(CodeWriter writer, TestMethodMetad writer.AppendLine("}"); } - private static bool IsAsyncMethod(IMethodSymbol method) - { - var returnType = method.ReturnType; - - var returnTypeName = returnType.ToDisplayString(); - return returnTypeName.StartsWith("System.Threading.Tasks.Task") || - returnTypeName.StartsWith("System.Threading.Tasks.ValueTask") || - returnTypeName.StartsWith("Task<") || - returnTypeName.StartsWith("ValueTask<"); - } - - private static bool ReturnsValueTask(IMethodSymbol method) - { - var returnTypeName = method.ReturnType.ToDisplayString(); - return returnTypeName.StartsWith("System.Threading.Tasks.ValueTask"); - } - private enum TestReturnPattern { Void, // void methods @@ -2475,7 +2314,7 @@ private static void GenerateReturnHandling( } } - private static void GenerateDependencies(CodeWriter writer, Compilation compilation, IMethodSymbol methodSymbol) + private static void GenerateDependencies(CodeWriter writer, IMethodSymbol methodSymbol) { var dependsOnAttributes = methodSymbol.GetAttributes() .Concat(methodSymbol.ContainingType.GetAttributes()) @@ -2538,7 +2377,7 @@ private static void GenerateTestDependency(CodeWriter writer, AttributeData attr if (arg.Type?.Name == "String") { var testName = arg.Value?.ToString() ?? ""; - + if (genericTypeArgument != null) { // DependsOnAttribute(string testName) - dependency on specific test in class T @@ -2575,7 +2414,7 @@ private static void GenerateTestDependency(CodeWriter writer, AttributeData attr if (firstArg.Type?.Name == "String" && secondArg.Type is IArrayTypeSymbol) { var testName = firstArg.Value?.ToString() ?? ""; - + if (genericTypeArgument != null) { // DependsOnAttribute(string testName, Type[] parameterTypes) - dependency on specific test with parameters in class T @@ -2680,70 +2519,6 @@ private static bool GetProceedOnFailureValue(AttributeData attributeData) return false; } - private static string GetDefaultValueString(IParameterSymbol parameter) - { - if (!parameter.HasExplicitDefaultValue) - { - return $"default({parameter.Type.GloballyQualified()})"; - } - - var defaultValue = parameter.ExplicitDefaultValue; - if (defaultValue == null) - { - return "null"; - } - - var type = parameter.Type; - - // Handle string - if (type.SpecialType == SpecialType.System_String) - { - return $"\"{defaultValue.ToString().Replace("\\", "\\\\").Replace("\"", "\\\"")}\""; - } - - // Handle char - if (type.SpecialType == SpecialType.System_Char) - { - return $"'{defaultValue}'"; - } - - // Handle bool - if (type.SpecialType == SpecialType.System_Boolean) - { - return defaultValue.ToString().ToLowerInvariant(); - } - - // Handle numeric types with proper suffixes - if (type.SpecialType == SpecialType.System_Single) - { - return $"{defaultValue}f"; - } - if (type.SpecialType == SpecialType.System_Double) - { - return $"{defaultValue}d"; - } - if (type.SpecialType == SpecialType.System_Decimal) - { - return $"{defaultValue}m"; - } - if (type.SpecialType == SpecialType.System_Int64) - { - return $"{defaultValue}L"; - } - if (type.SpecialType == SpecialType.System_UInt32) - { - return $"{defaultValue}u"; - } - if (type.SpecialType == SpecialType.System_UInt64) - { - return $"{defaultValue}ul"; - } - - // Default for other types - return defaultValue.ToString(); - } - - private static bool IsMethodHiding(IMethodSymbol derivedMethod, IMethodSymbol baseMethod) { // Must have same name @@ -3158,11 +2933,6 @@ private static void GenerateGenericParameterConstraints(CodeWriter writer, IType writer.AppendLine("},"); } - private static bool IsGenericTypeParameter(ITypeSymbol type) - { - return type.TypeKind == TypeKind.TypeParameter; - } - private static bool ContainsGenericTypeParameter(ITypeSymbol type) { if (type.TypeKind == TypeKind.TypeParameter) @@ -3186,8 +2956,7 @@ private static bool ContainsGenericTypeParameter(ITypeSymbol type) private static void GenerateGenericTestWithConcreteTypes( CodeWriter writer, TestMethodMetadata testMethod, - string className, - string combinationGuid) + string className) { var compilation = testMethod.Context!.Value.SemanticModel.Compilation; var methodName = testMethod.MethodSymbol.Name; @@ -3488,7 +3257,7 @@ private static void GenerateGenericTestWithConcreteTypes( foreach (var mdsAttr in methodDataSourceAttributes) { // Try to infer types from the method data source - var inferredTypes = InferClassTypesFromMethodDataSource(compilation, testMethod, mdsAttr); + var inferredTypes = InferClassTypesFromMethodDataSource(testMethod, mdsAttr); if (inferredTypes is { Length: > 0 }) { var typeKey = BuildTypeKey(inferredTypes); @@ -3542,7 +3311,7 @@ private static void GenerateGenericTestWithConcreteTypes( foreach (var mdsAttr in methodDataSourceAttributes) { // Try to infer types from the method data source - var inferredTypes = InferTypesFromMethodDataSource(compilation, testMethod, mdsAttr); + var inferredTypes = InferTypesFromMethodDataSource(testMethod, mdsAttr); if (inferredTypes is { Length: > 0 }) { var typeKey = BuildTypeKey(inferredTypes); @@ -3786,76 +3555,6 @@ private static void GenerateGenericTestWithConcreteTypes( writer.AppendLine("yield return genericMetadata;"); } - private static void ProcessGenerateGenericTestAttribute( - AttributeData genAttr, - TestMethodMetadata testMethod, - string className, - CodeWriter writer, - HashSet processedTypeCombinations, - bool isClassLevel) - { - // Extract type arguments from the attribute - if (genAttr.ConstructorArguments.Length == 0) - { - return; - } - - var typeArgs = new List(); - foreach (var arg in genAttr.ConstructorArguments) - { - if (arg is { Kind: TypedConstantKind.Type, Value: ITypeSymbol typeSymbol }) - { - typeArgs.Add(typeSymbol); - } - else if (arg.Kind == TypedConstantKind.Array) - { - foreach (var arrayElement in arg.Values) - { - if (arrayElement is { Kind: TypedConstantKind.Type, Value: ITypeSymbol arrayTypeSymbol }) - { - typeArgs.Add(arrayTypeSymbol); - } - } - } - } - - if (typeArgs.Count == 0) - { - return; - } - - var inferredTypes = typeArgs.ToArray(); - var typeKey = BuildTypeKey(inferredTypes); - - // Skip if we've already processed this type combination - if (!processedTypeCombinations.Add(typeKey)) - { - return; - } - - // Validate constraints based on whether this is a class-level or method-level attribute - bool constraintsValid; - if (isClassLevel) - { - // For class-level [GenerateGenericTest], validate against class type constraints - constraintsValid = ValidateClassTypeConstraints(testMethod.TypeSymbol, inferredTypes); - } - else - { - // For method-level [GenerateGenericTest], validate against method type constraints - constraintsValid = ValidateTypeConstraints(testMethod.MethodSymbol, inferredTypes); - } - - if (constraintsValid) - { - // Generate a concrete instantiation for this type combination - // Use the same key format as runtime: FullName ?? Name - writer.AppendLine($"[{string.Join(" + \",\" + ", inferredTypes.Select(FormatTypeForRuntimeName))}] = "); - GenerateConcreteTestMetadata(writer, testMethod, className, inferredTypes); - writer.AppendLine(","); - } - } - private static List ExtractTypeArgumentSets(List attributes) { var result = new List(); @@ -4490,7 +4189,7 @@ private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol a return null; } - private static ITypeSymbol[]? InferTypesFromMethodDataSource(Compilation compilation, TestMethodMetadata testMethod, AttributeData mdsAttr) + private static ITypeSymbol[]? InferTypesFromMethodDataSource(TestMethodMetadata testMethod, AttributeData mdsAttr) { if (mdsAttr.ConstructorArguments.Length == 0) { @@ -4569,7 +4268,7 @@ private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol a return inferredTypes; } - private static ITypeSymbol[]? InferClassTypesFromMethodDataSource(Compilation compilation, TestMethodMetadata testMethod, AttributeData mdsAttr) + private static ITypeSymbol[]? InferClassTypesFromMethodDataSource(TestMethodMetadata testMethod, AttributeData mdsAttr) { if (mdsAttr.ConstructorArguments.Length == 0) { @@ -4687,23 +4386,6 @@ private static void ProcessTypeForGenerics(ITypeSymbol paramType, ITypeSymbol ac } } - private static bool ValidateTypeConstraints(INamedTypeSymbol classType, ITypeSymbol[] typeArguments) - { - // Validate constraints for a generic class - if (!classType.IsGenericType) - { - return true; - } - - var typeParams = classType.TypeParameters; - if (typeParams.Length != typeArguments.Length) - { - return false; - } - - return ValidateTypeParameterConstraints(typeParams, typeArguments); - } - private static bool ValidateTypeConstraints(IMethodSymbol method, ITypeSymbol[] typeArguments) { // Only validate method type parameters here - class type parameters are validated separately @@ -4788,7 +4470,6 @@ private static void GenerateConcreteTestMetadata( ITypeSymbol[] typeArguments, AttributeData? specificArgumentsAttribute = null) { - var compilation = testMethod.Context!.Value.SemanticModel.Compilation; var methodName = testMethod.MethodSymbol.Name; // Separate class type arguments from method type arguments @@ -5025,7 +4706,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( } } - GenerateDependencies(writer, compilation, methodSymbol); + GenerateDependencies(writer, methodSymbol); // Generate attribute factory with filtered attributes var filteredAttributes = new List(); @@ -5372,7 +5053,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric( // Generate metadata - GenerateDependencies(writer, compilation, testMethod.MethodSymbol); + GenerateDependencies(writer, testMethod.MethodSymbol); // Generate attribute factory writer.AppendLine("AttributeFactory = static () =>"); diff --git a/TUnit.Core.SourceGenerator/Models/DiagnosticContext.cs b/TUnit.Core.SourceGenerator/Models/DiagnosticContext.cs deleted file mode 100644 index 0660f97f96..0000000000 --- a/TUnit.Core.SourceGenerator/Models/DiagnosticContext.cs +++ /dev/null @@ -1,92 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace TUnit.Core.SourceGenerator.Models; - -/// -/// Context for collecting and reporting diagnostics during source generation -/// -public class DiagnosticContext -{ - private readonly List _diagnostics = - [ - ]; - private readonly SourceProductionContext _sourceProductionContext; - - public DiagnosticContext(SourceProductionContext sourceProductionContext) - { - _sourceProductionContext = sourceProductionContext; - } - - /// - /// Reports a diagnostic immediately - /// - public void ReportDiagnostic(Diagnostic diagnostic) - { - _sourceProductionContext.ReportDiagnostic(diagnostic); - _diagnostics.Add(diagnostic); - } - - /// - /// Creates and reports an error diagnostic - /// - public void ReportError(string id, string title, string message, Location? location = null) - { - var diagnostic = Diagnostic.Create( - new DiagnosticDescriptor( - id, - title, - message, - "TUnit", - DiagnosticSeverity.Error, - isEnabledByDefault: true), - location ?? Location.None); - - ReportDiagnostic(diagnostic); - } - - /// - /// Creates and reports a warning diagnostic - /// - public void ReportWarning(string id, string title, string message, Location? location = null) - { - var diagnostic = Diagnostic.Create( - new DiagnosticDescriptor( - id, - title, - message, - "TUnit", - DiagnosticSeverity.Warning, - isEnabledByDefault: true), - location ?? Location.None); - - ReportDiagnostic(diagnostic); - } - - /// - /// Creates and reports an info diagnostic - /// - public void ReportInfo(string id, string title, string message, Location? location = null) - { - var diagnostic = Diagnostic.Create( - new DiagnosticDescriptor( - id, - title, - message, - "TUnit", - DiagnosticSeverity.Info, - isEnabledByDefault: true), - location ?? Location.None); - - ReportDiagnostic(diagnostic); - } - - /// - /// Gets all diagnostics that have been reported - /// - public IReadOnlyList GetDiagnostics() => _diagnostics.AsReadOnly(); - - /// - /// Checks if any errors have been reported - /// - public bool HasErrors => _diagnostics.Exists(d => d.Severity == DiagnosticSeverity.Error); -} diff --git a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyDataModel.cs b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyDataModel.cs deleted file mode 100644 index 2618bf09cf..0000000000 --- a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyDataModel.cs +++ /dev/null @@ -1,63 +0,0 @@ -namespace TUnit.Core.SourceGenerator.Models.Extracted; - -/// -/// Primitive representation of a property with a data source. -/// Contains only strings and primitives - no Roslyn symbols. -/// Used for both instance and static property injection. -/// -public sealed class PropertyDataModel : IEquatable -{ - // Property identity - public required string PropertyName { get; init; } - public required string PropertyTypeName { get; init; } - public required string ContainingTypeName { get; init; } - public required string MinimalContainingTypeName { get; init; } - public required string Namespace { get; init; } - public required string AssemblyName { get; init; } - - // Property characteristics - public required bool IsStatic { get; init; } - public required bool HasPublicGetter { get; init; } - public required bool HasPublicSetter { get; init; } - - // Data source - public required DataSourceModel DataSource { get; init; } - - // Attributes - public required EquatableArray PropertyAttributes { get; init; } - - public bool Equals(PropertyDataModel? other) - { - if (other is null) - { - return false; - } - - if (ReferenceEquals(this, other)) - { - return true; - } - - return PropertyName == other.PropertyName - && ContainingTypeName == other.ContainingTypeName - && IsStatic == other.IsStatic - && DataSource.Equals(other.DataSource); - } - - public override bool Equals(object? obj) - { - return Equals(obj as PropertyDataModel); - } - - public override int GetHashCode() - { - unchecked - { - var hash = PropertyName.GetHashCode(); - hash = (hash * 397) ^ ContainingTypeName.GetHashCode(); - hash = (hash * 397) ^ IsStatic.GetHashCode(); - hash = (hash * 397) ^ DataSource.GetHashCode(); - return hash; - } - } -} diff --git a/TUnit.Core.SourceGenerator/Models/GenericTestRegistration.cs b/TUnit.Core.SourceGenerator/Models/GenericTestRegistration.cs deleted file mode 100644 index 6fd7f9d861..0000000000 --- a/TUnit.Core.SourceGenerator/Models/GenericTestRegistration.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System.Collections.Immutable; -using Microsoft.CodeAnalysis; - -namespace TUnit.Core.SourceGenerator.Models; - -/// -/// Model representing a generic test class registration for AOT support. -/// -public record GenericTestRegistration -{ - /// - /// The generic type definition (e.g., MyTest<>). - /// - public required INamedTypeSymbol GenericTypeDefinition { get; init; } - - /// - /// The concrete type arguments (e.g., int, string). - /// - public required ImmutableArray TypeArguments { get; init; } - - /// - /// The fully qualified name of the concrete type (e.g., MyTest). - /// - public required string ConcreteTypeName { get; init; } - - /// - /// The constructed concrete type. - /// - public required INamedTypeSymbol ConcreteType { get; init; } -} diff --git a/TUnit.Core.SourceGenerator/Models/PropertyInjectionContext.cs b/TUnit.Core.SourceGenerator/Models/PropertyInjectionContext.cs deleted file mode 100644 index 3e404365cf..0000000000 --- a/TUnit.Core.SourceGenerator/Models/PropertyInjectionContext.cs +++ /dev/null @@ -1,43 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace TUnit.Core.SourceGenerator.Models; - -/// -/// Context for property injection generation containing all necessary information -/// -public class PropertyInjectionContext : IEquatable -{ - public required INamedTypeSymbol ClassSymbol { get; init; } - public required string ClassName { get; init; } - public required string SafeClassName { get; init; } - public DiagnosticContext? DiagnosticContext { get; init; } - - public bool Equals(PropertyInjectionContext? other) - { - if (ReferenceEquals(null, other)) - return false; - if (ReferenceEquals(this, other)) - return true; - - return SymbolEqualityComparer.Default.Equals(ClassSymbol, other.ClassSymbol) && - ClassName == other.ClassName && - SafeClassName == other.SafeClassName; - // Note: DiagnosticContext is not included in equality as it's contextual/runtime state - } - - public override bool Equals(object? obj) - { - return Equals(obj as PropertyInjectionContext); - } - - public override int GetHashCode() - { - unchecked - { - var hashCode = SymbolEqualityComparer.Default.GetHashCode(ClassSymbol); - hashCode = (hashCode * 397) ^ ClassName.GetHashCode(); - hashCode = (hashCode * 397) ^ SafeClassName.GetHashCode(); - return hashCode; - } - } -} \ No newline at end of file diff --git a/TUnit.Core.SourceGenerator/Models/StaticClassDataSourceInjectorModel.cs b/TUnit.Core.SourceGenerator/Models/StaticClassDataSourceInjectorModel.cs deleted file mode 100644 index 13e7e6e335..0000000000 --- a/TUnit.Core.SourceGenerator/Models/StaticClassDataSourceInjectorModel.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace TUnit.Core.SourceGenerator.Models; - -public record StaticClassDataSourceInjectorModel -{ - public required string FullyQualifiedTypeName { get; init; } - public required string PropertyName { get; init; } - public required string InjectableType { get; init; } - public required string MinimalTypeName { get; set; } -} diff --git a/TUnit.Core.SourceGenerator/Models/TestDefinitionContext.cs b/TUnit.Core.SourceGenerator/Models/TestDefinitionContext.cs deleted file mode 100644 index 18542df551..0000000000 --- a/TUnit.Core.SourceGenerator/Models/TestDefinitionContext.cs +++ /dev/null @@ -1,222 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace TUnit.Core.SourceGenerator.Models; - -/// -/// Context used when building individual test definitions. -/// This is a subset of TestGenerationContext focused on what's needed for a single test. -/// -public class TestDefinitionContext : IEquatable -{ - public required TestMetadataGenerationContext GenerationContext { get; init; } - public required AttributeData? ClassDataAttribute { get; init; } - public required AttributeData? MethodDataAttribute { get; init; } - public required int TestIndex { get; init; } - public required int RepeatIndex { get; init; } - - /// - /// Creates contexts for all test definitions based on data attributes - /// - public static IEnumerable CreateContexts(TestMetadataGenerationContext generationContext) - { - var testInfo = generationContext.TestInfo; - - // Get all data source attributes that can be handled at compile time - var classDataAttrs = testInfo.TypeSymbol.GetAttributes() - .Where(attr => IsCompileTimeDataSourceAttribute(attr)) - .ToList(); - - var methodDataAttrs = testInfo.MethodSymbol.GetAttributes() - .Where(attr => IsCompileTimeDataSourceAttribute(attr)) - .ToList(); - - // Extract repeat count - var repeatCount = ExtractRepeatCount(testInfo.MethodSymbol); - if (repeatCount == 0) - { - repeatCount = 1; // Default to 1 if no repeat attribute - } - - var testIndex = 0; - - // Convert to arrays once upfront for better performance (avoid repeated .Any() calls and enumeration) - var classDataArray = classDataAttrs.Count > 0 ? classDataAttrs.ToArray() : []; - var methodDataArray = methodDataAttrs.Count > 0 ? methodDataAttrs.ToArray() : []; - var hasClassData = classDataArray.Length > 0; - var hasMethodData = methodDataArray.Length > 0; - - if (!hasClassData && !hasMethodData) - { - for (var repeatIndex = 0; repeatIndex < repeatCount; repeatIndex++) - { - yield return new TestDefinitionContext - { - GenerationContext = generationContext, - ClassDataAttribute = null, - MethodDataAttribute = null, - TestIndex = testIndex++, - RepeatIndex = repeatIndex - }; - } - yield break; - } - - if (hasClassData && !hasMethodData) - { - // Use array indexing instead of foreach for slightly better performance - for (var i = 0; i < classDataArray.Length; i++) - { - for (var repeatIndex = 0; repeatIndex < repeatCount; repeatIndex++) - { - yield return new TestDefinitionContext - { - GenerationContext = generationContext, - ClassDataAttribute = classDataArray[i], - MethodDataAttribute = null, - TestIndex = testIndex++, - RepeatIndex = repeatIndex - }; - } - } - } - else if (!hasClassData && hasMethodData) - { - // Use array indexing instead of foreach for slightly better performance - for (var i = 0; i < methodDataArray.Length; i++) - { - for (var repeatIndex = 0; repeatIndex < repeatCount; repeatIndex++) - { - yield return new TestDefinitionContext - { - GenerationContext = generationContext, - ClassDataAttribute = null, - MethodDataAttribute = methodDataArray[i], - TestIndex = testIndex++, - RepeatIndex = repeatIndex - }; - } - } - } - // If we have both class and method data - create cartesian product with array indexing - else - { - // Use array indexing for cartesian product for better performance - for (var i = 0; i < classDataArray.Length; i++) - { - for (var j = 0; j < methodDataArray.Length; j++) - { - for (var repeatIndex = 0; repeatIndex < repeatCount; repeatIndex++) - { - yield return new TestDefinitionContext - { - GenerationContext = generationContext, - ClassDataAttribute = classDataArray[i], - MethodDataAttribute = methodDataArray[j], - TestIndex = testIndex++, - RepeatIndex = repeatIndex - }; - } - } - } - } - } - - private static int ExtractRepeatCount(IMethodSymbol methodSymbol) - { - var repeatAttribute = methodSymbol.GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.Name == "RepeatAttribute"); - - if (repeatAttribute is { ConstructorArguments.Length: > 0 }) - { - if (repeatAttribute.ConstructorArguments[0].Value is int count) - { - return count; - } - } - - return 0; - } - - private static bool IsCompileTimeDataSourceAttribute(AttributeData attr) - { - var attrName = attr.AttributeClass?.Name; - - // These can be handled at compile time through code generation: - // - ArgumentsAttribute (direct data) - // - MethodDataSourceAttribute (generate lambda to call method) - // - Attributes inheriting from AsyncDataSourceGeneratorAttribute (generate lambda to instantiate and call) - - if (attrName is "ArgumentsAttribute" or "MethodDataSourceAttribute") - { - return true; - } - - // Check if it inherits from AsyncDataSourceGeneratorAttribute - var baseType = attr.AttributeClass?.BaseType; - while (baseType != null) - { - if (baseType.Name == "AsyncDataSourceGeneratorAttribute") - { - return true; - } - baseType = baseType.BaseType; - } - - return false; - } - - public bool Equals(TestDefinitionContext? other) - { - if (ReferenceEquals(null, other)) - return false; - if (ReferenceEquals(this, other)) - return true; - - return GenerationContext.Equals(other.GenerationContext) && - AttributeDataEquals(ClassDataAttribute, other.ClassDataAttribute) && - AttributeDataEquals(MethodDataAttribute, other.MethodDataAttribute) && - TestIndex == other.TestIndex && - RepeatIndex == other.RepeatIndex; - } - - public override bool Equals(object? obj) - { - return Equals(obj as TestDefinitionContext); - } - - public override int GetHashCode() - { - unchecked - { - var hashCode = GenerationContext.GetHashCode(); - hashCode = (hashCode * 397) ^ AttributeDataGetHashCode(ClassDataAttribute); - hashCode = (hashCode * 397) ^ AttributeDataGetHashCode(MethodDataAttribute); - hashCode = (hashCode * 397) ^ TestIndex; - hashCode = (hashCode * 397) ^ RepeatIndex; - return hashCode; - } - } - - private static bool AttributeDataEquals(AttributeData? x, AttributeData? y) - { - if (ReferenceEquals(x, y)) return true; - if (x is null || y is null) return false; - - return SymbolEqualityComparer.Default.Equals(x.AttributeClass, y.AttributeClass) && - x.ConstructorArguments.Length == y.ConstructorArguments.Length && - x.ConstructorArguments.Zip(y.ConstructorArguments, (a, b) => TypedConstantEquals(a, b)).All(eq => eq); - } - - private static bool TypedConstantEquals(TypedConstant x, TypedConstant y) - { - if (x.Kind != y.Kind) return false; - if (!SymbolEqualityComparer.Default.Equals(x.Type, y.Type)) return false; - return Equals(x.Value, y.Value); - } - - private static int AttributeDataGetHashCode(AttributeData? attr) - { - if (attr is null) return 0; - return SymbolEqualityComparer.Default.GetHashCode(attr.AttributeClass); - } -} diff --git a/TUnit.Core.SourceGenerator/Models/TestHookCollectionDataModel.cs b/TUnit.Core.SourceGenerator/Models/TestHookCollectionDataModel.cs deleted file mode 100644 index aada64c476..0000000000 --- a/TUnit.Core.SourceGenerator/Models/TestHookCollectionDataModel.cs +++ /dev/null @@ -1,24 +0,0 @@ -namespace TUnit.Core.SourceGenerator.Models; - -public record TestHookCollectionDataModel(IEnumerable HooksDataModels) -{ - public virtual bool Equals(TestHookCollectionDataModel? other) - { - if (ReferenceEquals(null, other)) - { - return false; - } - - if (ReferenceEquals(this, other)) - { - return true; - } - - return HooksDataModels.SequenceEqual(other.HooksDataModels); - } - - public override int GetHashCode() - { - return HooksDataModels.GetHashCode(); - } -} diff --git a/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs b/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs deleted file mode 100644 index 9b52e5e691..0000000000 --- a/TUnit.Core.SourceGenerator/Models/TypeWithDataSourceProperties.cs +++ /dev/null @@ -1,23 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace TUnit.Core.SourceGenerator.Models; - -public struct TypeWithDataSourceProperties -{ - public INamedTypeSymbol TypeSymbol { get; init; } - public List Properties { get; init; } -} - -public sealed class TypeWithDataSourcePropertiesComparer : IEqualityComparer -{ - public bool Equals(TypeWithDataSourceProperties x, TypeWithDataSourceProperties y) - { - // Compare based on the type symbol - this handles partial classes correctly - return SymbolEqualityComparer.Default.Equals(x.TypeSymbol, y.TypeSymbol); - } - - public int GetHashCode(TypeWithDataSourceProperties obj) - { - return SymbolEqualityComparer.Default.GetHashCode(obj.TypeSymbol); - } -}