diff --git a/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj b/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj index 1f2708cc24..9e5718edba 100644 --- a/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj +++ b/TUnit.Core.SourceGenerator.Roslyn414/TUnit.Core.SourceGenerator.Roslyn414.csproj @@ -4,6 +4,12 @@ 4.14 + + + CodeGenerators\Writers\Hooks + + + diff --git a/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj b/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj index a321b8fc20..109e28962a 100644 --- a/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj +++ b/TUnit.Core.SourceGenerator.Roslyn44/TUnit.Core.SourceGenerator.Roslyn44.csproj @@ -4,6 +4,12 @@ 4.4 + + + CodeGenerators\Writers\Hooks + + + diff --git a/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj b/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj index 7a9abf09c6..c4571db44e 100644 --- a/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj +++ b/TUnit.Core.SourceGenerator.Roslyn47/TUnit.Core.SourceGenerator.Roslyn47.csproj @@ -4,6 +4,12 @@ 4.7 + + + CodeGenerators\Writers\Hooks + + + diff --git a/TUnit.Core.SourceGenerator/Builders/ITestDefinitionBuilder.cs b/TUnit.Core.SourceGenerator/Builders/ITestDefinitionBuilder.cs deleted file mode 100644 index 17dd95f525..0000000000 --- a/TUnit.Core.SourceGenerator/Builders/ITestDefinitionBuilder.cs +++ /dev/null @@ -1,19 +0,0 @@ -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.Builders; - -/// -/// Interface for building test definitions -/// -internal interface ITestDefinitionBuilder -{ - /// - /// Determines if this builder can handle the given context - /// - bool CanBuild(TestMetadataGenerationContext context); - - /// - /// Builds test definitions and writes them to the code writer - /// - void BuildTestDefinitions(CodeWriter writer, TestMetadataGenerationContext context); -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs b/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs index daac82b9e0..4ebec87f72 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs @@ -2,9 +2,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.CodeGenerators.Writers; using TUnit.Core.SourceGenerator.Extensions; -using TUnit.Core.SourceGenerator.Utilities; namespace TUnit.Core.SourceGenerator; @@ -13,68 +11,6 @@ namespace TUnit.Core.SourceGenerator; /// internal static class CodeGenerationHelpers { - /// - /// Generates C# code for a ParameterMetadata array from method parameters. - /// - public static string GenerateParameterMetadataArray(IMethodSymbol method) - { - if (method.Parameters.Length == 0) - { - return "global::System.Array.Empty()"; - } - - using var writer = new CodeWriter("", includeHeader: false); - writer.SetIndentLevel(2); - using (writer.BeginArrayInitializer("new global::TUnit.Core.ParameterMetadata[]")) - { - foreach (var param in method.Parameters) - { - var parameterIndex = method.Parameters.IndexOf(param); - var containsTypeParam = ContainsTypeParameter(param.Type); - var typeForConstructor = containsTypeParam ? "object" : param.Type.GloballyQualified(); - - using (writer.BeginObjectInitializer($"new global::TUnit.Core.ParameterMetadata(typeof({typeForConstructor}))", ",")) - { - writer.AppendLine($"Name = \"{param.Name}\","); - writer.AppendLine($"TypeInfo = {GenerateTypeInfo(param.Type)},"); - writer.AppendLine($"IsNullable = {param.Type.IsNullable().ToString().ToLowerInvariant()},"); - var paramTypesArray = GenerateParameterTypesArray(method); - if (paramTypesArray == "null") - { - writer.AppendLine($"ReflectionInfo = typeof({method.ContainingType.GloballyQualified()}).GetMethods(global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Static).FirstOrDefault(m => m.Name == \"{method.Name}\" && m.GetParameters().Length == {method.Parameters.Length})?.GetParameters()[{parameterIndex}],"); - } - else - { - writer.AppendLine($"ReflectionInfo = typeof({method.ContainingType.GloballyQualified()}).GetMethod(\"{method.Name}\", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Static, null, {paramTypesArray}, null)!.GetParameters()[{parameterIndex}],"); - } - - // Generate cached data source attributes for AOT compatibility - // Include both IDataSourceAttribute and IDataSourceMemberAttribute implementations - var dataSourceAttributes = param.GetAttributes() - .Where(attr => attr.AttributeClass != null && - attr.AttributeClass.AllInterfaces.Any(i => - i.Name == "IDataSourceAttribute" || i.Name == "IDataSourceMemberAttribute")) - .ToArray(); - - if (dataSourceAttributes.Length > 0) - { - writer.AppendLine($"CachedDataSourceAttributes = new global::System.Attribute[]"); - writer.AppendLine("{"); - writer.SetIndentLevel(3); - foreach (var attr in dataSourceAttributes) - { - var attrCode = GenerateAttributeInstantiation(attr, method.Parameters); - writer.AppendLine($"{attrCode},"); - } - writer.SetIndentLevel(2); - writer.Append("}"); - } - } - } - } - return writer.ToString().TrimEnd(); // Trim trailing newline for inline use - } - /// /// Generates direct instantiation code for attributes. /// @@ -289,7 +225,6 @@ private static bool IsParamsArrayArgument(AttributeData attr) return typeName is "global::TUnit.Core.ArgumentsAttribute" or "global::TUnit.Core.InlineDataAttribute"; } - /// /// Determines if an attribute should be excluded from metadata. /// @@ -320,120 +255,6 @@ public static bool ContainsTypeParameter(ITypeSymbol type) return false; } - /// - /// Gets a safe type name for use in typeof() expressions. - /// Returns "object" only for actual type parameters or types containing them. - /// Returns open generic forms (e.g., List<>) for generic type definitions. - /// - - - /// - /// Generates C# code for PropertyMetadata array from class properties. - /// - public static string GeneratePropertyMetadataArray(INamedTypeSymbol typeSymbol) - { - var properties = typeSymbol.GetMembers() - .OfType() - .Where(p => p.DeclaredAccessibility == Accessibility.Public && !p.IsStatic) - .ToList(); - - if (properties.Count == 0) - { - return "global::System.Array.Empty()"; - } - - using var writer = new CodeWriter("", includeHeader: false); - writer.SetIndentLevel(2); - using (writer.BeginArrayInitializer("new global::TUnit.Core.PropertyMetadata[]")) - { - foreach (var prop in properties) - { - using (writer.BeginObjectInitializer("new global::TUnit.Core.PropertyMetadata", ",")) - { - writer.AppendLine($"Name = \"{prop.Name}\","); - writer.AppendLine($"Type = typeof({prop.Type.GloballyQualified()}),"); - writer.AppendLine($"ReflectionInfo = typeof({typeSymbol.GloballyQualified()}).GetProperty(\"{prop.Name}\"),"); - writer.AppendLine("IsStatic = false,"); - writer.AppendLine($"IsNullable = {prop.Type.IsNullable().ToString().ToLowerInvariant()},"); - writer.AppendLine($"Getter = obj => ((({typeSymbol.GloballyQualified()})obj).{prop.Name}),"); - writer.AppendLine("ClassMetadata = null!,"); - writer.AppendLine("ContainingTypeMetadata = null!"); - } - } - } - return writer.ToString().TrimEnd(); // Trim trailing newline for inline use - } - - /// - /// Generates C# code for ConstructorMetadata array from class constructors. - /// - - /// - /// Generates C# code for class-level data source providers. - /// - public static string GenerateClassDataSourceProviders(INamedTypeSymbol typeSymbol) - { - var dataSourceAttributes = typeSymbol.GetAttributes() - .Where(attr => IsDataSourceAttribute(attr)) - .ToList(); - - if (dataSourceAttributes.Count == 0) - { - return "global::System.Array.Empty()"; - } - - using var writer = new CodeWriter("", includeHeader: false); - writer.SetIndentLevel(2); - using (writer.BeginArrayInitializer("new global::TUnit.Core.TestDataSource[]")) - { - foreach (var attr in dataSourceAttributes) - { - var providerCode = GenerateDataSourceProvider(attr, typeSymbol); - if (!string.IsNullOrEmpty(providerCode)) - { - writer.AppendLine($"{providerCode},"); - } - } - } - return writer.ToString().TrimEnd(); // Trim trailing newline for inline use - } - - /// - /// Generates C# code for method-level data source providers. - /// - public static string GenerateMethodDataSourceProviders(IMethodSymbol methodSymbol) - { - var dataSourceAttributes = methodSymbol.GetAttributes() - .Where(attr => IsDataSourceAttribute(attr)) - .ToList(); - - // Also check method parameters for data attributes - foreach (var param in methodSymbol.Parameters) - { - dataSourceAttributes.AddRange(param.GetAttributes().Where(IsDataSourceAttribute)); - } - - if (dataSourceAttributes.Count == 0) - { - return "global::System.Array.Empty()"; - } - - using var writer = new CodeWriter("", includeHeader: false); - writer.SetIndentLevel(2); - using (writer.BeginArrayInitializer("new global::TUnit.Core.TestDataSource[]")) - { - foreach (var attr in dataSourceAttributes) - { - var providerCode = GenerateDataSourceProvider(attr, methodSymbol.ContainingType); - if (!string.IsNullOrEmpty(providerCode)) - { - writer.AppendLine($"{providerCode},"); - } - } - } - return writer.ToString().TrimEnd(); // Trim trailing newline for inline use - } - /// /// Determines if an attribute is a data source attribute. /// @@ -448,114 +269,6 @@ private static bool IsDataSourceAttribute(AttributeData attr) return attr.AttributeClass.AllInterfaces.Any(i => i.GloballyQualified() == "global::TUnit.Core.IDataSourceAttribute"); } - /// - /// Generates a data source provider instance based on the attribute type. - /// - private static string GenerateDataSourceProvider(AttributeData attr, INamedTypeSymbol containingType) - { - var fullName = attr.AttributeClass!.GloballyQualified(); - - switch (fullName) - { - case "TUnit.Core.ArgumentsAttribute": - case "TUnit.Core.InlineDataAttribute": - return GenerateInlineDataProvider(attr); - - case "TUnit.Core.MethodDataSourceAttribute": - return GenerateMethodDataSourceProvider(attr, containingType); - - case "TUnit.Core.PropertyDataSourceAttribute": - return GeneratePropertyDataSourceProvider(attr, containingType); - - default: - // For custom IDataSourceAttribute implementations (including ClassDataSourceAttribute) - return GenerateCustomDataProvider(attr); - } - } - - private static string GenerateInlineDataProvider(AttributeData attr) - { - using var writer = new CodeWriter("", includeHeader: false); - writer.Append("new global::TUnit.Core.StaticTestDataSource(new object?[][] { new object?[] { "); - - var args = attr.ConstructorArguments.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg)).ToList(); - writer.Append(string.Join(", ", args)); - writer.Append(" } })"); - return writer.ToString().Trim(); - } - - private static string GenerateMethodDataSourceProvider(AttributeData attr, INamedTypeSymbol containingType) - { - if (attr.ConstructorArguments.Length == 0) - { - return ""; - } - - var methodName = attr.ConstructorArguments[0].Value?.ToString() ?? ""; - var isShared = attr.NamedArguments.FirstOrDefault(na => na.Key == "Shared").Value.Value as bool? ?? false; - - // Try to determine if this can be optimized for AOT - var method = FindDataSourceMethod(methodName, containingType); - if (method != null && ShouldUseAotOptimizedDataSource(method)) - { - // Generate code that uses a more AOT-friendly approach - return GenerateAotOptimizedDataSource(methodName, containingType, isShared); - } - - // Fall back to DynamicTestDataSource for complex cases - return $"new global::TUnit.Core.DynamicTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ SourceType = typeof({containingType.GloballyQualified()}), SourceMemberName = \"{methodName}\" }}"; - } - - private static string GenerateParameterTypesArray(IMethodSymbol method) - { - if (method.Parameters.Length == 0) - { - return "global::System.Type.EmptyTypes"; - } - - if (method.Parameters.Any(p => ContainsTypeParameter(p.Type))) - { - return "null"; - } - - var parameterTypes = method.Parameters - .Select(p => $"typeof({p.Type.GloballyQualified()})") - .ToArray(); - - return $"new global::System.Type[] {{ {string.Join(", ", parameterTypes)} }}"; - } - - - private static string GeneratePropertyDataSourceProvider(AttributeData attr, INamedTypeSymbol containingType) - { - if (attr.ConstructorArguments.Length == 0) - { - return ""; - } - - var propertyName = attr.ConstructorArguments[0].Value?.ToString() ?? ""; - var isShared = attr.NamedArguments.FirstOrDefault(na => na.Key == "Shared").Value.Value as bool? ?? false; - - // Check if we can use AOT-friendly approach for property data sources - var property = containingType.GetMembers(propertyName).OfType().FirstOrDefault(); - if (property != null && ShouldUseAotOptimizedPropertyDataSource(property)) - { - return GenerateAotOptimizedPropertyDataSource(propertyName, containingType, isShared); - } - - return $"new global::TUnit.Core.DynamicTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ SourceType = typeof({containingType.GloballyQualified()}), SourceMemberName = \"{propertyName}\" }}"; - } - - private static string GenerateCustomDataProvider(AttributeData attr) - { - // For custom data attributes that implement IDataSourceAttribute (including AsyncDataSourceGeneratorAttribute), - // we need to instantiate the attribute and use it directly - var writer = new CodeWriter(); - AttributeWriter.WriteAttributeWithoutSyntax(writer, attr); - return writer.ToString(); - } - - /// /// Generates all test-related attributes for the TestMetadata.AttributesByType field as a dictionary. /// @@ -668,122 +381,4 @@ private static int GetGenericParameterPosition(ITypeParameterSymbol typeParamete } return 0; } - - private static string GetGenericTypeDefinitionName(INamedTypeSymbol namedType) - { - // Get the unbound generic type (e.g., List`1) - var unboundType = namedType.ConstructUnboundGenericType(); - return GetAssemblyQualifiedName(unboundType); - } - - private static string GetAssemblyQualifiedName(ITypeSymbol typeSymbol) - { - // Build assembly qualified name - var typeName = typeSymbol.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithoutGlobalPrefix); - - if (typeSymbol.ContainingAssembly.Name is "System.Private.CoreLib" or "mscorlib") - { - return $"{typeName}, System.Private.CoreLib"; - } - - return $"{typeName}, {typeSymbol.ContainingAssembly.Name}"; - } - - #region Compile-Time Data Source Resolution - - /// - /// Finds a data source method in the containing type. - /// - private static IMethodSymbol? FindDataSourceMethod(string methodName, INamedTypeSymbol containingType) - { - return containingType.GetMembers(methodName) - .OfType() - .FirstOrDefault(m => m.Parameters.Length == 0); - } - - /// - /// Determines if a method should use AOT-optimized data source generation. - /// - private static bool ShouldUseAotOptimizedDataSource(IMethodSymbol method) - { - if (!method.IsStatic) - { - return false; - } - if (method.Parameters.Length > 0) - { - return false; - } - - var returnType = method.ReturnType; - - if (returnType is INamedTypeSymbol namedType) - { - var typeString = namedType.ToDisplayString(); - - return typeString.Contains("IEnumerable<") || - typeString.Contains("ICollection<") || - typeString.Contains("List<") || - typeString == "object[][]" || - typeString.Contains("object[]"); - } - - return false; - } - - /// - /// Generates AOT-optimized data source code that avoids reflection. - /// - private static string GenerateAotOptimizedDataSource(string methodName, INamedTypeSymbol containingType, bool isShared) - { - return $"new global::TUnit.Core.AotFriendlyTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ " + - $"MethodInvoker = () => {containingType.GloballyQualified()}.{methodName}(), " + - $"SourceType = typeof({containingType.GloballyQualified()}), " + - $"SourceMemberName = \"{methodName}\" }}"; - } - - - /// - /// Generates AOT-optimized property data source code that avoids reflection. - /// - private static string GenerateAotOptimizedPropertyDataSource(string propertyName, INamedTypeSymbol containingType, bool isShared) - { - return $"new global::TUnit.Core.AotFriendlyTestDataSource({isShared.ToString().ToLowerInvariant()}) {{ " + - $"MethodInvoker = () => new {containingType.GloballyQualified()}().{propertyName}, " + - $"SourceType = typeof({containingType.GloballyQualified()}), " + - $"SourceMemberName = \"{propertyName}\" }}"; - } - - /// - /// Determines if a property should use AOT-optimized data source generation. - /// - private static bool ShouldUseAotOptimizedPropertyDataSource(IPropertySymbol property) - { - if (!property.IsStatic) - { - var containingType = property.ContainingType; - var hasParameterlessConstructor = containingType.Constructors.Any(c => c.Parameters.Length == 0); - if (!hasParameterlessConstructor) - { - return false; - } - } - - var returnType = property.Type; - - if (returnType is INamedTypeSymbol namedType) - { - var typeString = namedType.ToDisplayString(); - - return typeString.Contains("IEnumerable<") || - typeString.Contains("ICollection<") || - typeString.Contains("List<") || - typeString == "object[][]" || - typeString.Contains("object[]"); - } - - return false; - } - - #endregion } diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs index 37e93199f1..ec78c7f4d3 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TupleArgumentHelper.cs @@ -1,41 +1,10 @@ using Microsoft.CodeAnalysis; using TUnit.Core.SourceGenerator.Extensions; -using TUnit.Core.SourceGenerator.Models; namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; public static class TupleArgumentHelper { - public static List GenerateArgumentAccess(ITypeSymbol parameterType, string argumentsArrayName, int baseIndex) - { - var argumentExpressions = new List(); - - // For method parameters, tuples are NOT supported - the data source - // must return already unpacked values matching the method signature - var castExpression = $"global::TUnit.Core.Helpers.CastHelper.Cast<{parameterType.GloballyQualified()}>({argumentsArrayName}[{baseIndex}])"; - argumentExpressions.Add(castExpression); - - return argumentExpressions; - } - - /// The types of all constructor parameters - /// The name of the arguments array (e.g., "args") - /// A list of argument access expressions for the constructor - public static List GenerateConstructorArgumentAccess(IList parameterTypes, string argumentsArrayName) - { - var argumentExpressions = new List(); - - // Data sources already provide unwrapped arguments, so we just access by index - for (var i = 0; i < parameterTypes.Count; i++) - { - var parameterType = parameterTypes[i]; - var castExpression = $"global::TUnit.Core.Helpers.CastHelper.Cast<{parameterType.GloballyQualified()}>({argumentsArrayName}[{i}])"; - argumentExpressions.Add(castExpression); - } - - return argumentExpressions; - } - /// /// Generates method invocation arguments. /// @@ -45,17 +14,17 @@ public static List GenerateConstructorArgumentAccess(IList public static string GenerateMethodInvocationArguments(IList parameters, string argumentsArrayName) { var allArguments = new List(); - + for (var i = 0; i < parameters.Count; i++) { var parameter = parameters[i]; var castExpression = $"global::TUnit.Core.Helpers.CastHelper.Cast<{parameter.Type.GloballyQualified()}>({argumentsArrayName}[{i}])"; allArguments.Add(castExpression); } - + return string.Join(", ", allArguments); } - + /// /// Generates argument access for a method with possible params array, given a specific argument count. /// @@ -83,10 +52,10 @@ public static List GenerateArgumentAccessWithParams(IList 0 && parameters[parameters.Count - 1].IsParams; - + if (!hasParams) { // No params array - just cast each argument @@ -114,7 +83,7 @@ public static List GenerateArgumentAccessWithParams(IList GenerateArgumentAccessWithParams(IList GenerateArgumentAccessWithParams(IList c.Parameters.Length > 0); } -} \ No newline at end of file +} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs index 571c87a76c..fc38c7e5e2 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs @@ -1,8 +1,5 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Core.SourceGenerator.CodeGenerators.Formatting; -using TUnit.Core.SourceGenerator.Extensions; namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; @@ -10,66 +7,9 @@ public static class TypedConstantParser { private static readonly TypedConstantFormatter _formatter = new(); - public static string GetFullyQualifiedTypeNameFromTypedConstantValue(TypedConstant typedConstant) - { - if (typedConstant.Kind == TypedConstantKind.Type) - { - var type = (INamedTypeSymbol) typedConstant.Value!; - return type.GloballyQualified(); - } - - if (typedConstant.Kind == TypedConstantKind.Enum) - { - return typedConstant.Type!.GloballyQualified(); - } - - if (typedConstant.Kind is not TypedConstantKind.Error and not TypedConstantKind.Array) - { - return $"global::{typedConstant.Value!.GetType().FullName}"; - } - - return typedConstant.Type!.GloballyQualified(); - } - public static string GetRawTypedConstantValue(TypedConstant typedConstant, ITypeSymbol? targetType = null) { // Use the formatter for consistent handling return _formatter.FormatForCode(typedConstant, targetType); } - - public static string FormatPrimitive(object? value) - { - // Check for special floating-point values first - var specialFloatValue = SpecialFloatingPointValuesHelper.TryFormatSpecialFloatingPointValue(value); - if (specialFloatValue != null) - { - return specialFloatValue; - } - - switch (value) - { - case string s: - return SymbolDisplay.FormatLiteral(s, quote: true); - case char c: - return SymbolDisplay.FormatLiteral(c, quote: true); - case bool b: - return b ? "true" : "false"; - case null: - return "null"; - // Use InvariantCulture for numeric types to ensure consistent formatting - case double d: - return d.ToString(System.Globalization.CultureInfo.InvariantCulture) + "d"; - case float f: - return f.ToString(System.Globalization.CultureInfo.InvariantCulture) + "f"; - case decimal dec: - return dec.ToString(System.Globalization.CultureInfo.InvariantCulture) + "m"; - default: - // For other numeric types, use InvariantCulture - if (value is IFormattable formattable) - { - return formattable.ToString(null, System.Globalization.CultureInfo.InvariantCulture); - } - return value.ToString() ?? "null"; - } - } } diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs index ab2323c2a8..3485736116 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs @@ -1,5 +1,4 @@ -using System.Collections.Immutable; -using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; using TUnit.Core.SourceGenerator.Extensions; diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/AssemblyHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/AssemblyHooksWriter.cs deleted file mode 100644 index 92c54b64e3..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/AssemblyHooksWriter.cs +++ /dev/null @@ -1,59 +0,0 @@ -using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.Enums; -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks; - -public static class AssemblyHooksWriter -{ - public static void Execute(ICodeWriter sourceBuilder, HooksDataModel? model) - { - if (model is null) - { - return; - } - - if (model.HookLocationType == HookLocationType.Before) - { - sourceBuilder.Append("new global::TUnit.Core.Hooks.BeforeAssemblyHookMethod"); - } - else - { - sourceBuilder.Append("new global::TUnit.Core.Hooks.AfterAssemblyHookMethod"); - } - - sourceBuilder.Append("{"); - sourceBuilder.Append("MethodInfo = "); - SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ','); - - sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model)})),"); - - sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},"); - sourceBuilder.Append($"Order = {model.Order},"); - sourceBuilder.Append($"RegistrationIndex = global::TUnit.Core.HookRegistrationIndices.GetNext{(model.HookLocationType == HookLocationType.Before ? "Before" : "After")}{(model.IsEveryHook ? "Every" : "")}AssemblyHookIndex(),"); - sourceBuilder.Append($"""FilePath = @"{model.FilePath}","""); - sourceBuilder.Append($"LineNumber = {model.LineNumber},"); - - sourceBuilder.Append("},"); - } - - private static string GetArgs(HooksDataModel model) - { - List args = []; - - foreach (var type in model.ParameterTypes) - { - if (type == WellKnownFullyQualifiedClassNames.AssemblyHookContext.WithGlobalPrefix) - { - args.Add("context"); - } - - if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix) - { - args.Add("cancellationToken"); - } - } - - return string.Join(", ", args); - } -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/BaseHookWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/BaseHookWriter.cs deleted file mode 100644 index f9017e7e8d..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/BaseHookWriter.cs +++ /dev/null @@ -1,38 +0,0 @@ -using TUnit.Core.SourceGenerator.Enums; -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks; - -public class BaseHookWriter -{ - protected static string GetArgs(HooksDataModel model) - { - List args = []; - - var expectedType = model.HookLevel switch - { - "TUnit.Core.HookType.Test" => WellKnownFullyQualifiedClassNames.TestContext, - "TUnit.Core.HookType.Class" => WellKnownFullyQualifiedClassNames.ClassHookContext, - "TUnit.Core.HookType.Assembly" => WellKnownFullyQualifiedClassNames.AssemblyHookContext, - "TUnit.Core.HookType.TestSession" => WellKnownFullyQualifiedClassNames.TestSessionContext, - "TUnit.Core.HookType.TestDiscovery" when model.HookLocationType == HookLocationType.Before => WellKnownFullyQualifiedClassNames.BeforeTestDiscoveryContext, - "TUnit.Core.HookType.TestDiscovery" when model.HookLocationType == HookLocationType.After => WellKnownFullyQualifiedClassNames.TestDiscoveryContext, - _ => throw new ArgumentOutOfRangeException() - }; - - foreach (var type in model.ParameterTypes) - { - if (type == expectedType.WithGlobalPrefix) - { - args.Add("context"); - } - - if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix) - { - args.Add("cancellationToken"); - } - } - - return string.Join(", ", args); - } -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/ClassHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/ClassHooksWriter.cs deleted file mode 100644 index e9c4a5ff8c..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/ClassHooksWriter.cs +++ /dev/null @@ -1,54 +0,0 @@ -using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.Enums; -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks; - -public static class ClassHooksWriter -{ - public static void Execute(ICodeWriter sourceBuilder, HooksDataModel model) - { - if (model.HookLocationType == HookLocationType.Before) - { - sourceBuilder.Append("new global::TUnit.Core.Hooks.BeforeClassHookMethod"); - } - else - { - sourceBuilder.Append("new global::TUnit.Core.Hooks.AfterClassHookMethod"); - } - - sourceBuilder.Append("{"); - sourceBuilder.Append("MethodInfo = "); - SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ','); - - sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model)})),"); - - sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},"); - sourceBuilder.Append($"Order = {model.Order},"); - sourceBuilder.Append($"RegistrationIndex = global::TUnit.Core.HookRegistrationIndices.GetNext{(model.HookLocationType == HookLocationType.Before ? "Before" : "After")}{(model.IsEveryHook ? "Every" : "")}ClassHookIndex(),"); - sourceBuilder.Append($"""FilePath = @"{model.FilePath}","""); - sourceBuilder.Append($"LineNumber = {model.LineNumber},"); - - sourceBuilder.Append("},"); - } - - private static string GetArgs(HooksDataModel model) - { - List args = []; - - foreach (var type in model.ParameterTypes) - { - if (type == WellKnownFullyQualifiedClassNames.ClassHookContext.WithGlobalPrefix) - { - args.Add("context"); - } - - if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix) - { - args.Add("cancellationToken"); - } - } - - return string.Join(", ", args); - } -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/GlobalTestHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/GlobalTestHooksWriter.cs deleted file mode 100644 index 365fc2bec3..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/GlobalTestHooksWriter.cs +++ /dev/null @@ -1,102 +0,0 @@ -using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.Enums; -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks; - -public static class GlobalTestHooksWriter -{ - public static void Execute(ICodeWriter sourceBuilder, HooksDataModel model) - { - sourceBuilder.Append($"new {GetClassType(model.HookLevel, model.HookLocationType)}"); - sourceBuilder.Append("{"); - - sourceBuilder.Append("MethodInfo = "); - SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ','); - - sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model, model.HookLocationType)})),"); - - sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},"); - sourceBuilder.Append($"Order = {model.Order},"); - sourceBuilder.Append($"RegistrationIndex = {GetRegistrationIndexMethod(model.HookLevel, model.HookLocationType, model.IsEveryHook)},"); - sourceBuilder.Append($"""FilePath = @"{model.FilePath}","""); - sourceBuilder.Append($"LineNumber = {model.LineNumber},"); - - sourceBuilder.Append("},"); - } - - private static string GetClassType(string hookType, HookLocationType hookLocationType) - { - if (hookLocationType == HookLocationType.Before) - { - return hookType switch - { - "TUnit.Core.HookType.Test" => "global::TUnit.Core.Hooks.BeforeTestHookMethod", - "TUnit.Core.HookType.Class" => "global::TUnit.Core.Hooks.BeforeClassHookMethod", - "TUnit.Core.HookType.Assembly" => "global::TUnit.Core.Hooks.BeforeAssemblyHookMethod", - "TUnit.Core.HookType.TestSession" => "global::TUnit.Core.Hooks.BeforeTestSessionHookMethod", - "TUnit.Core.HookType.TestDiscovery" => "global::TUnit.Core.Hooks.BeforeTestDiscoveryHookMethod", - _ => throw new ArgumentOutOfRangeException(nameof(hookType), hookType, null) - }; - } - - return hookType switch - { - "TUnit.Core.HookType.Test" => "global::TUnit.Core.Hooks.AfterTestHookMethod", - "TUnit.Core.HookType.Class" => "global::TUnit.Core.Hooks.AfterClassHookMethod", - "TUnit.Core.HookType.Assembly" => "global::TUnit.Core.Hooks.AfterAssemblyHookMethod", - "TUnit.Core.HookType.TestSession" => "global::TUnit.Core.Hooks.AfterTestSessionHookMethod", - "TUnit.Core.HookType.TestDiscovery" => "global::TUnit.Core.Hooks.AfterTestDiscoveryHookMethod", - _ => throw new ArgumentOutOfRangeException(nameof(hookType), hookType, null) - }; - } - - private static string GetArgs(HooksDataModel model, HookLocationType hookLocationType) - { - List args = []; - - var expectedType = model.HookLevel switch - { - "TUnit.Core.HookType.Test" => WellKnownFullyQualifiedClassNames.TestContext, - "TUnit.Core.HookType.Class" => WellKnownFullyQualifiedClassNames.ClassHookContext, - "TUnit.Core.HookType.Assembly" => WellKnownFullyQualifiedClassNames.AssemblyHookContext, - "TUnit.Core.HookType.TestSession" => WellKnownFullyQualifiedClassNames.TestSessionContext, - "TUnit.Core.HookType.TestDiscovery" when hookLocationType == HookLocationType.Before => WellKnownFullyQualifiedClassNames.BeforeTestDiscoveryContext, - "TUnit.Core.HookType.TestDiscovery" when hookLocationType == HookLocationType.After => WellKnownFullyQualifiedClassNames.TestDiscoveryContext, - _ => throw new ArgumentOutOfRangeException() - }; - - foreach (var type in model.ParameterTypes) - { - if (type == expectedType.WithGlobalPrefix) - { - args.Add("context"); - } - - if (type == WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix) - { - args.Add("cancellationToken"); - } - } - - return string.Join(", ", args); - } - - private static string GetRegistrationIndexMethod(string hookType, HookLocationType hookLocationType, bool isEveryHook) - { - var hookTypeSimple = hookType switch - { - "TUnit.Core.HookType.Test" => "Test", - "TUnit.Core.HookType.Class" => "Class", - "TUnit.Core.HookType.Assembly" => "Assembly", - "TUnit.Core.HookType.TestSession" => "TestSession", - "TUnit.Core.HookType.TestDiscovery" => "TestDiscovery", - _ => throw new ArgumentOutOfRangeException(nameof(hookType), hookType, null) - }; - - var prefix = hookLocationType == HookLocationType.Before ? "Before" : "After"; - var everyPart = isEveryHook ? "Every" : ""; - - return $"global::TUnit.Core.HookRegistrationIndices.GetNext{prefix}{everyPart}{hookTypeSimple}HookIndex()"; - } -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs index 80fdd3135a..460410e5ea 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs @@ -1,7 +1,6 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis; using TUnit.Core.SourceGenerator.Enums; -using TUnit.Core.SourceGenerator.Extensions; using TUnit.Core.SourceGenerator.Utilities; namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers; diff --git a/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs index 23edbd23f4..68348a2c4b 100644 --- a/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/AttributeDataExtensions.cs @@ -52,24 +52,4 @@ public static bool IsTypedDataSourceAttribute(this AttributeData? attributeData) return typedInterface?.TypeArguments.FirstOrDefault(); } - - public static bool IsNonGlobalHook(this AttributeData attributeData, Compilation compilation) - { - // Cache type symbols to avoid repeated GetTypeByMetadataName calls - var beforeAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.BeforeAttribute.WithoutGlobalPrefix); - var afterAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.AfterAttribute.WithoutGlobalPrefix); - - return SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, beforeAttribute) - || SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, afterAttribute); - } - - public static bool IsGlobalHook(this AttributeData attributeData, Compilation compilation) - { - // Cache type symbols to avoid repeated GetTypeByMetadataName calls - var beforeEveryAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.BeforeEveryAttribute.WithoutGlobalPrefix); - var afterEveryAttribute = compilation.GetTypeByMetadataName(WellKnownFullyQualifiedClassNames.AfterEveryAttribute.WithoutGlobalPrefix); - - return SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, beforeEveryAttribute) - || SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, afterEveryAttribute); - } } diff --git a/TUnit.Core.SourceGenerator/Extensions/CompilationExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/CompilationExtensions.cs deleted file mode 100644 index c561359edc..0000000000 --- a/TUnit.Core.SourceGenerator/Extensions/CompilationExtensions.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using TUnit.Core.SourceGenerator.Extensions; - -namespace TUnit.Analyzers.Extensions; - -public static class CompilationExtensions -{ - public static bool HasImplicitConversionOrGenericParameter(this Compilation compilation, ITypeSymbol? argumentType, - ITypeSymbol? parameterType) - { - if (parameterType?.IsGenericDefinition() == false) - { - if (argumentType is null) - { - return false; - } - - var conversion = compilation.ClassifyConversion(argumentType, parameterType); - - return conversion.IsImplicit || conversion.IsNumeric; - } - - if (parameterType is INamedTypeSymbol { IsGenericType: true, TypeArguments: [{ TypeKind: TypeKind.TypeParameter }] } namedType) - { - // `IEnumerable<>` - if (argumentType is IArrayTypeSymbol { ElementType: { } elementType }) - { - var specializedSuper = namedType.OriginalDefinition.Construct(elementType); - return compilation.HasImplicitConversion(argumentType, specializedSuper); - } - - if (argumentType is INamedTypeSymbol { IsGenericType: true, TypeArguments: [{ } genericArgument] }) - { - var specializedSuper = namedType.OriginalDefinition.Construct(genericArgument); - return compilation.HasImplicitConversion(argumentType, specializedSuper); - } - } - - return compilation.HasImplicitConversion(argumentType?.OriginalDefinition, parameterType?.OriginalDefinition); - } -} diff --git a/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs index c82b1199cd..d28002c83e 100644 --- a/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs @@ -4,7 +4,13 @@ namespace TUnit.Core.SourceGenerator.Extensions; public static class MethodExtensions { - public static AttributeData? GetTestAttribute(this IMethodSymbol methodSymbol) + public static AttributeData GetRequiredTestAttribute(this IMethodSymbol methodSymbol) + { + return GetTestAttribute(methodSymbol) ?? + throw new ArgumentException($"No test attribute found on {methodSymbol.ContainingType.Name}.{methodSymbol.Name}"); + } + + private static AttributeData? GetTestAttribute(IMethodSymbol methodSymbol) { var attributes = methodSymbol.GetAttributes(); @@ -17,20 +23,4 @@ public static class MethodExtensions .FirstOrDefault(x => x.AttributeClass?.BaseType?.GloballyQualified() == WellKnownFullyQualifiedClassNames.BaseTestAttribute.WithGlobalPrefix); } - - public static AttributeData GetRequiredTestAttribute(this IMethodSymbol methodSymbol) - { - return GetTestAttribute(methodSymbol) ?? - throw new ArgumentException($"No test attribute found on {methodSymbol.ContainingType.Name}.{methodSymbol.Name}"); - } - - public static bool IsTest(this IMethodSymbol methodSymbol) - { - return methodSymbol.GetTestAttribute() != null; - } - - public static bool IsHook(this IMethodSymbol methodSymbol, Compilation compilation) - { - return methodSymbol.GetAttributes().Any(x => x.IsNonGlobalHook(compilation) || x.IsGlobalHook(compilation)); - } } diff --git a/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs index 971586cc2c..c9d419ff67 100644 --- a/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/SymbolExtensions.cs @@ -1,5 +1,4 @@ -using System.Collections.Generic; -using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis; namespace TUnit.Core.SourceGenerator.Extensions; @@ -28,7 +27,7 @@ public static bool IsConst(this ISymbol? symbol, out object? constantValue) constantValue = null; return false; } - + /// /// Creates an IEqualityComparer for tuples that uses SymbolEqualityComparer for symbol comparison /// @@ -36,21 +35,21 @@ public static bool IsConst(this ISymbol? symbol, out object? constantValue) { return new TupleSymbolComparer(comparer); } - + private class TupleSymbolComparer : IEqualityComparer<(INamedTypeSymbol, string)> { private readonly IEqualityComparer _symbolComparer; - + public TupleSymbolComparer(IEqualityComparer symbolComparer) { _symbolComparer = symbolComparer; } - + public bool Equals((INamedTypeSymbol, string) x, (INamedTypeSymbol, string) y) { return _symbolComparer.Equals(x.Item1, y.Item1) && x.Item2 == y.Item2; } - + public int GetHashCode((INamedTypeSymbol, string) obj) { var hash1 = _symbolComparer.GetHashCode(obj.Item1); diff --git a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs index d9c530c89d..ba44c376b4 100644 --- a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs @@ -1,33 +1,11 @@ -using System.Collections.Immutable; -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.CodeAnalysis; using System.Text; using Microsoft.CodeAnalysis; -using TUnit.Analyzers.Extensions; -using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; namespace TUnit.Core.SourceGenerator.Extensions; public static class TypeExtensions { - private static readonly Dictionary ReservedTypeKeywords = new() - { - { "System.Boolean", "bool" }, - { "System.Byte", "byte" }, - { "System.SByte", "sbyte" }, - { "System.Char", "char" }, - { "System.Decimal", "decimal" }, - { "System.Double", "double" }, - { "System.Single", "float" }, - { "System.Int32", "int" }, - { "System.UInt32", "uint" }, - { "System.Int64", "long" }, - { "System.UInt64", "ulong" }, - { "System.Int16", "short" }, - { "System.UInt16", "ushort" }, - { "System.Object", "object" }, - { "System.String", "string" } - }; - public static string GetMetadataName(this Type type) { return $"{type.Namespace}.{type.Name}"; diff --git a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs index e6675284c4..6d2e909489 100644 --- a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs @@ -1,7 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.CodeGenerators.Writers; using TUnit.Core.SourceGenerator.Extensions; using TUnit.Core.SourceGenerator.Models; using TUnit.Core.SourceGenerator.Models.Extracted; diff --git a/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs b/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs deleted file mode 100644 index 091a767e0d..0000000000 --- a/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs +++ /dev/null @@ -1,359 +0,0 @@ -using System.Collections.Immutable; -using Microsoft.CodeAnalysis; -using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.Extensions; - -namespace TUnit.Core.SourceGenerator.Helpers; - -/// -/// Provides generic type inference capabilities for test methods with data sources -/// -internal static class GenericTypeInference -{ - private static ImmutableArray? TryInferFromTypedDataSources(IMethodSymbol method, ImmutableArray attributes) - { - foreach (var attribute in attributes) - { - if (attribute.AttributeClass == null) - { - continue; - } - - // Check if this is a typed data source (inherits from AsyncDataSourceGeneratorAttribute or DataSourceGeneratorAttribute) - var baseType = GetTypedDataSourceBase(attribute.AttributeClass); - if (baseType is { TypeArguments.Length: > 0 }) - { - // For single type parameter methods, use the first type argument - if (method.TypeParameters.Length == 1) - { - return ImmutableArray.Create(baseType.TypeArguments[0]); - } - - // For multiple type parameters, match by parameter position if possible - if (baseType.TypeArguments.Length >= method.TypeParameters.Length) - { - return baseType.TypeArguments.Take(method.TypeParameters.Length).ToImmutableArray(); - } - } - } - - return null; - } - - private static INamedTypeSymbol? GetTypedDataSourceBase(INamedTypeSymbol attributeClass) - { - var current = attributeClass; - while (current != null) - { - // Check if it's a generic base class - if (current.IsGenericType) - { - var name = current.Name; - if (name.Contains("DataSourceGeneratorAttribute") || - name.Contains("AsyncDataSourceGeneratorAttribute")) - { - return current; - } - } - - current = current.BaseType; - } - - return null; - } - - private static ImmutableArray? TryInferFromTypeInferringAttributes(IMethodSymbol method) - { - var inferredTypes = new List(); - - // Look at each parameter to see if it has attributes that implement IInfersType - foreach (var parameter in method.Parameters) - { - if (parameter.Type is ITypeParameterSymbol typeParam) - { - // Check if this parameter has attributes that implement IInfersType - foreach (var attr in parameter.GetAttributes()) - { - if (attr.AttributeClass != null) - { - // Look for IInfersType in the attribute's interfaces - var infersTypeInterface = attr.AttributeClass.AllInterfaces - .FirstOrDefault(i => i.GloballyQualifiedNonGeneric() == "global::TUnit.Core.Interfaces.IInfersType" && - i.IsGenericType && - i.TypeArguments.Length == 1); - - if (infersTypeInterface != null) - { - // Get the type argument from IInfersType - var inferredType = infersTypeInterface.TypeArguments[0]; - - // Find the index of this type parameter - var typeParamIndex = -1; - for (var i = 0; i < method.TypeParameters.Length; i++) - { - if (method.TypeParameters[i].Name == typeParam.Name) - { - typeParamIndex = i; - break; - } - } - - if (typeParamIndex >= 0) - { - // Make sure we have enough slots - while (inferredTypes.Count <= typeParamIndex) - { - inferredTypes.Add(null!); - } - inferredTypes[typeParamIndex] = inferredType; - } - } - } - } - } - } - - // Remove any null entries and check if we have all types - inferredTypes.RemoveAll(t => t == null); - - return inferredTypes.Count == method.TypeParameters.Length - ? inferredTypes.ToImmutableArray() - : null; - } - - private static ITypeSymbol? InferTypeFromValue(TypedConstant value) - { - if (value.IsNull) - { - return null; - } - - // The type of the constant value tells us what T should be - return value.Type; - } - - private static ImmutableArray? TryInferFromMethodDataSource(IMethodSymbol testMethod, ImmutableArray attributes) - { - var methodDataSourceAttributes = attributes - .Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute") - .ToList(); - - if (!methodDataSourceAttributes.Any()) - { - return null; - } - - foreach (var attr in methodDataSourceAttributes) - { - if (attr.ConstructorArguments.Length > 0 && attr.ConstructorArguments[0].Value is string methodName) - { - // Find the data source method - var dataSourceMethod = testMethod.ContainingType.GetMembers(methodName) - .OfType() - .FirstOrDefault(); - - if (dataSourceMethod is { ReturnType: INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: > 0 } namedType }) - // Analyze the return type to extract generic types - // Handle IEnumerable> - { - var funcType = namedType.TypeArguments[0]; - if (funcType is INamedTypeSymbol { Name: "Func", TypeArguments.Length: > 0 } funcNamedType) - { - var tupleType = funcNamedType.TypeArguments[0]; - if (tupleType is INamedTypeSymbol { IsTupleType: true } tupleNamedType) - { - // Extract types from tuple elements - var inferredTypes = InferTypesFromTupleElements(testMethod, tupleNamedType); - if (inferredTypes != null) - { - return inferredTypes; - } - } - } - } - } - } - - return null; - } - - private static ImmutableArray? InferTypesFromTupleElements(IMethodSymbol testMethod, INamedTypeSymbol tupleType) - { - var inferredTypes = new ITypeSymbol[testMethod.TypeParameters.Length]; - var tupleElements = tupleType.TupleElements; - - // Map tuple elements to method parameters - for (var i = 0; i < testMethod.Parameters.Length && i < tupleElements.Length; i++) - { - var parameter = testMethod.Parameters[i]; - var tupleElement = tupleElements[i]; - - if (parameter.Type is ITypeParameterSymbol typeParam) - { - // Find the index of this type parameter - var typeParamIndex = -1; - for (var j = 0; j < testMethod.TypeParameters.Length; j++) - { - if (testMethod.TypeParameters[j].Name == typeParam.Name) - { - typeParamIndex = j; - break; - } - } - - if (typeParamIndex >= 0) - { - // For generic types like IEnumerable, extract T - var elementType = tupleElement.Type; - if (elementType is INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: > 0 } namedElementType) - { - // For IEnumerable, we want int - inferredTypes[typeParamIndex] = namedElementType.TypeArguments[0]; - } - else - { - // For direct types - inferredTypes[typeParamIndex] = elementType; - } - } - } - else if (parameter.Type is INamedTypeSymbol { IsGenericType: true } paramNamedType) - { - // Handle complex generic parameters like Func - // This is more complex and would need deeper analysis - var tupleElementType = tupleElement.Type; - if (tupleElementType is INamedTypeSymbol { Name: "Func" } funcType) - { - // Match type arguments between parameter type and tuple element type - for (var j = 0; j < funcType.TypeArguments.Length && j < paramNamedType.TypeArguments.Length; j++) - { - var paramTypeArg = paramNamedType.TypeArguments[j]; - if (paramTypeArg is ITypeParameterSymbol funcTypeParam) - { - var typeParamIndex = -1; - for (var k = 0; k < testMethod.TypeParameters.Length; k++) - { - if (testMethod.TypeParameters[k].Name == funcTypeParam.Name) - { - typeParamIndex = k; - break; - } - } - - if (typeParamIndex >= 0) - { - inferredTypes[typeParamIndex] = funcType.TypeArguments[j]; - } - } - } - } - } - } - - // Check if we have all required types - if (inferredTypes.All(t => t != null)) - { - return inferredTypes.ToImmutableArray(); - } - - return null; - } - - /// - /// Gets all unique generic type combinations for a method based on its data sources - /// - public static ImmutableArray> GetAllGenericTypeCombinations( - IMethodSymbol method, - ImmutableArray attributes) - { - var combinations = new List>(); - - // For Arguments attributes, each one might produce a different type combination - var argumentsAttributes = attributes - .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); - - foreach (var args in argumentsAttributes) - { - var types = InferTypesFromSingleArguments(method, args); - if (types != null && !combinations.Any(c => TypeArraysEqual(c, types.Value))) - { - combinations.Add(types.Value); - } - } - - // For typed data sources, we typically get one type combination - var typedSourceTypes = TryInferFromTypedDataSources(method, attributes); - if (typedSourceTypes != null && !combinations.Any(c => TypeArraysEqual(c, typedSourceTypes.Value))) - { - combinations.Add(typedSourceTypes.Value); - } - - // For parameter attributes that implement IInfersType - var inferredTypes = TryInferFromTypeInferringAttributes(method); - if (inferredTypes != null && !combinations.Any(c => TypeArraysEqual(c, inferredTypes.Value))) - { - combinations.Add(inferredTypes.Value); - } - - // For MethodDataSource attributes - var methodDataSourceTypes = TryInferFromMethodDataSource(method, attributes); - if (methodDataSourceTypes != null && !combinations.Any(c => TypeArraysEqual(c, methodDataSourceTypes.Value))) - { - combinations.Add(methodDataSourceTypes.Value); - } - - return combinations.ToImmutableArray(); - } - - private static ImmutableArray? InferTypesFromSingleArguments(IMethodSymbol method, AttributeData args) - { - if (!method.IsGenericMethod || args.ConstructorArguments.Length == 0) - { - return null; - } - - var inferredTypes = new List(); - - for (var i = 0; i < method.TypeParameters.Length && i < method.Parameters.Length; i++) - { - var parameter = method.Parameters[i]; - - if (parameter.Type is ITypeParameterSymbol) - { - if (i < args.ConstructorArguments.Length) - { - var argValue = args.ConstructorArguments[i]; - var inferredType = InferTypeFromValue(argValue); - - if (inferredType != null) - { - inferredTypes.Add(inferredType); - } - } - } - } - - return inferredTypes.Count == method.TypeParameters.Length - ? inferredTypes.ToImmutableArray() - : null; - } - - private static bool TypeArraysEqual(ImmutableArray a, ImmutableArray b) - { - if (a.Length != b.Length) - { - return false; - } - - for (var i = 0; i < a.Length; i++) - { - if (!SymbolEqualityComparer.Default.Equals(a[i], b[i])) - { - return false; - } - } - - return true; - } -} diff --git a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs index 28fe2326b7..0494c65d1d 100644 --- a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs +++ b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs @@ -1,5 +1,3 @@ -using TUnit.Core.SourceGenerator.Models; - namespace TUnit.Core.SourceGenerator.Models.Extracted; /// diff --git a/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj b/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj index 4340cfc618..f6efb79d95 100644 --- a/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj +++ b/TUnit.Core.SourceGenerator/TUnit.Core.SourceGenerator.csproj @@ -24,4 +24,12 @@ + + + + + + + + diff --git a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs index fd2cc9e652..db6b93fda3 100644 --- a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs +++ b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs @@ -1,6 +1,5 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis; -using TUnit.Core.SourceGenerator.CodeGenerators; using TUnit.Core.SourceGenerator.Extensions; namespace TUnit.Core.SourceGenerator.Utilities;