Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 2 additions & 309 deletions TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t

if (hasTypedDataSource || hasGenerateGenericTest || testMethod.IsGenericMethod || hasClassArguments || hasTypedDataSourceForGenericType || hasMethodArgumentsForGenericType || hasMethodDataSourceForGenericType || hasClassDataSources)
{
GenerateGenericTestWithConcreteTypes(writer, testMethod, className, uniqueClassName);
GenerateGenericTestWithConcreteTypes(writer, testMethod, className);
}
else
{
Expand Down Expand Up @@ -350,138 +350,6 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t
GenerateModuleInitializer(writer, testMethod, uniqueClassName);
}

private static void GenerateSpecificGenericInstantiation(
CodeWriter writer,
TestMethodMetadata testMethod,
string className,
string combinationGuid,
ImmutableArray<ITypeSymbol> typeArguments)
{
var methodName = testMethod.MethodSymbol.Name;
var typeArgsString = string.Join(", ", typeArguments.Select(t => t.GloballyQualified()));
var instantiatedMethodName = $"{methodName}<{typeArgsString}>";

var concreteTestMethod = new TestMethodMetadata
{
MethodSymbol = testMethod.MethodSymbol,
TypeSymbol = testMethod.TypeSymbol,
FilePath = testMethod.FilePath,
LineNumber = testMethod.LineNumber,
TestAttribute = testMethod.TestAttribute,
Context = testMethod.Context,
MethodSyntax = testMethod.MethodSyntax,
IsGenericType = testMethod.IsGenericType,
IsGenericMethod = false, // We're creating a concrete instantiation
MethodAttributes = testMethod.MethodAttributes
};

writer.AppendLine($"// Generated instantiation for {instantiatedMethodName}");
writer.AppendLine("{");
writer.Indent();

writer.AppendLine($"var metadata = new global::TUnit.Core.TestMetadata<{className}>");
writer.AppendLine("{");
writer.Indent();

writer.AppendLine($"TestName = \"{instantiatedMethodName}\",");
writer.AppendLine($"TestClassType = {GenerateTypeReference(testMethod.TypeSymbol, testMethod.IsGenericType)},");
writer.AppendLine($"TestMethodName = \"{methodName}\",");
writer.AppendLine($"GenericMethodTypeArguments = new global::System.Type[] {{ {string.Join(", ", typeArguments.Select(t => $"typeof({t.GloballyQualified()})"))}}},");

GenerateMetadata(writer, concreteTestMethod);

if (testMethod.IsGenericType)
{
GenerateGenericTypeInfo(writer, testMethod.TypeSymbol);
}

GenerateAotFriendlyInvokers(writer, testMethod, className, typeArguments);

writer.AppendLine($"FilePath = @\"{(testMethod.FilePath ?? "").Replace("\\", "\\\\")}\",");
writer.AppendLine($"LineNumber = {testMethod.LineNumber},");

writer.Unindent();
writer.AppendLine("};");

writer.AppendLine("metadata.TestSessionId = testSessionId;");
writer.AppendLine("yield return metadata;");

writer.Unindent();
writer.AppendLine("}");
writer.AppendLine();
}

private static void GenerateAotFriendlyInvokers(
CodeWriter writer,
TestMethodMetadata testMethod,
string className,
ImmutableArray<ITypeSymbol> typeArguments)
{
var methodName = testMethod.MethodSymbol.Name;
var typeArgsString = string.Join(", ", typeArguments.Select(t => t.GloballyQualified()));
var hasCancellationToken = testMethod.MethodSymbol.Parameters.Any(p =>
p.Type.Name == "CancellationToken" &&
p.Type.ContainingNamespace?.ToString() == "System.Threading");

writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

writer.AppendLine($"return new {className}();");

writer.Unindent();
writer.AppendLine("},");

writer.AppendLine("InvokeTypedTest = static (instance, args, cancellationToken) =>");
writer.AppendLine("{");
writer.Indent();

// Wrap entire lambda body in try-catch to handle synchronous exceptions
writer.AppendLine("try");
writer.AppendLine("{");
writer.Indent();

// Generate direct method call with specific types (no MakeGenericMethod)
writer.AppendLine($"var typedInstance = ({className})instance;");

writer.AppendLine("var methodArgs = new object?[args.Length" + (hasCancellationToken ? " + 1" : "") + "];");
writer.AppendLine("global::System.Array.Copy(args, methodArgs, args.Length);");

if (hasCancellationToken)
{
writer.AppendLine("methodArgs[args.Length] = cancellationToken;");
}

var parameterCasts = new List<string>();
for (var i = 0; i < testMethod.MethodSymbol.Parameters.Length; i++)
{
var param = testMethod.MethodSymbol.Parameters[i];
if (param.Type.Name == "CancellationToken")
{
parameterCasts.Add("cancellationToken");
}
else
{
var paramType = ReplaceTypeParametersWithConcreteTypes(param.Type, testMethod.MethodSymbol.TypeParameters, typeArguments);
parameterCasts.Add($"({paramType.GloballyQualified()})methodArgs[{i}]!");
}
}

writer.AppendLine($"return global::TUnit.Core.AsyncConvert.Convert(() => typedInstance.{methodName}<{typeArgsString}>({string.Join(", ", parameterCasts)}));");

writer.Unindent();
writer.AppendLine("}");
writer.AppendLine("catch (global::System.Exception ex)");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine("return new global::System.Threading.Tasks.ValueTask(global::System.Threading.Tasks.Task.FromException(ex));");
writer.Unindent();
writer.AppendLine("}");

writer.Unindent();
writer.AppendLine("},");
}

private static ITypeSymbol ReplaceTypeParametersWithConcreteTypes(
ITypeSymbol type,
ImmutableArray<ITypeParameterSymbol> typeParameters,
Expand Down Expand Up @@ -2062,7 +1930,6 @@ private static void GenerateConcreteTestInvoker(CodeWriter writer, TestMethodMet
writer.AppendLine("},");
}


private static void GenerateEnumerateTestDescriptors(CodeWriter writer, TestMethodMetadata testMethod, string className)
{
var methodName = testMethod.MethodSymbol.Name;
Expand Down Expand Up @@ -2387,23 +2254,6 @@ private static void GenerateModuleInitializer(CodeWriter writer, TestMethodMetad
writer.AppendLine("}");
}

private static bool IsAsyncMethod(IMethodSymbol method)
{
var returnType = method.ReturnType;

var returnTypeName = returnType.ToDisplayString();
return returnTypeName.StartsWith("System.Threading.Tasks.Task") ||
returnTypeName.StartsWith("System.Threading.Tasks.ValueTask") ||
returnTypeName.StartsWith("Task<") ||
returnTypeName.StartsWith("ValueTask<");
}

private static bool ReturnsValueTask(IMethodSymbol method)
{
var returnTypeName = method.ReturnType.ToDisplayString();
return returnTypeName.StartsWith("System.Threading.Tasks.ValueTask");
}

private enum TestReturnPattern
{
Void, // void methods
Expand Down Expand Up @@ -2669,70 +2519,6 @@ private static bool GetProceedOnFailureValue(AttributeData attributeData)
return false;
}

private static string GetDefaultValueString(IParameterSymbol parameter)
{
if (!parameter.HasExplicitDefaultValue)
{
return $"default({parameter.Type.GloballyQualified()})";
}

var defaultValue = parameter.ExplicitDefaultValue;
if (defaultValue == null)
{
return "null";
}

var type = parameter.Type;

// Handle string
if (type.SpecialType == SpecialType.System_String)
{
return $"\"{defaultValue.ToString().Replace("\\", "\\\\").Replace("\"", "\\\"")}\"";
}

// Handle char
if (type.SpecialType == SpecialType.System_Char)
{
return $"'{defaultValue}'";
}

// Handle bool
if (type.SpecialType == SpecialType.System_Boolean)
{
return defaultValue.ToString().ToLowerInvariant();
}

// Handle numeric types with proper suffixes
if (type.SpecialType == SpecialType.System_Single)
{
return $"{defaultValue}f";
}
if (type.SpecialType == SpecialType.System_Double)
{
return $"{defaultValue}d";
}
if (type.SpecialType == SpecialType.System_Decimal)
{
return $"{defaultValue}m";
}
if (type.SpecialType == SpecialType.System_Int64)
{
return $"{defaultValue}L";
}
if (type.SpecialType == SpecialType.System_UInt32)
{
return $"{defaultValue}u";
}
if (type.SpecialType == SpecialType.System_UInt64)
{
return $"{defaultValue}ul";
}

// Default for other types
return defaultValue.ToString();
}


private static bool IsMethodHiding(IMethodSymbol derivedMethod, IMethodSymbol baseMethod)
{
// Must have same name
Expand Down Expand Up @@ -3147,11 +2933,6 @@ private static void GenerateGenericParameterConstraints(CodeWriter writer, IType
writer.AppendLine("},");
}

private static bool IsGenericTypeParameter(ITypeSymbol type)
{
return type.TypeKind == TypeKind.TypeParameter;
}

private static bool ContainsGenericTypeParameter(ITypeSymbol type)
{
if (type.TypeKind == TypeKind.TypeParameter)
Expand All @@ -3175,8 +2956,7 @@ private static bool ContainsGenericTypeParameter(ITypeSymbol type)
private static void GenerateGenericTestWithConcreteTypes(
CodeWriter writer,
TestMethodMetadata testMethod,
string className,
string combinationGuid)
string className)
{
var compilation = testMethod.Context!.Value.SemanticModel.Compilation;
var methodName = testMethod.MethodSymbol.Name;
Expand Down Expand Up @@ -3775,76 +3555,6 @@ private static void GenerateGenericTestWithConcreteTypes(
writer.AppendLine("yield return genericMetadata;");
}

private static void ProcessGenerateGenericTestAttribute(
AttributeData genAttr,
TestMethodMetadata testMethod,
string className,
CodeWriter writer,
HashSet<string> processedTypeCombinations,
bool isClassLevel)
{
// Extract type arguments from the attribute
if (genAttr.ConstructorArguments.Length == 0)
{
return;
}

var typeArgs = new List<ITypeSymbol>();
foreach (var arg in genAttr.ConstructorArguments)
{
if (arg is { Kind: TypedConstantKind.Type, Value: ITypeSymbol typeSymbol })
{
typeArgs.Add(typeSymbol);
}
else if (arg.Kind == TypedConstantKind.Array)
{
foreach (var arrayElement in arg.Values)
{
if (arrayElement is { Kind: TypedConstantKind.Type, Value: ITypeSymbol arrayTypeSymbol })
{
typeArgs.Add(arrayTypeSymbol);
}
}
}
}

if (typeArgs.Count == 0)
{
return;
}

var inferredTypes = typeArgs.ToArray();
var typeKey = BuildTypeKey(inferredTypes);

// Skip if we've already processed this type combination
if (!processedTypeCombinations.Add(typeKey))
{
return;
}

// Validate constraints based on whether this is a class-level or method-level attribute
bool constraintsValid;
if (isClassLevel)
{
// For class-level [GenerateGenericTest], validate against class type constraints
constraintsValid = ValidateClassTypeConstraints(testMethod.TypeSymbol, inferredTypes);
}
else
{
// For method-level [GenerateGenericTest], validate against method type constraints
constraintsValid = ValidateTypeConstraints(testMethod.MethodSymbol, inferredTypes);
}

if (constraintsValid)
{
// Generate a concrete instantiation for this type combination
// Use the same key format as runtime: FullName ?? Name
writer.AppendLine($"[{string.Join(" + \",\" + ", inferredTypes.Select(FormatTypeForRuntimeName))}] = ");
GenerateConcreteTestMetadata(writer, testMethod, className, inferredTypes);
writer.AppendLine(",");
}
}

private static List<ITypeSymbol[]> ExtractTypeArgumentSets(List<AttributeData> attributes)
{
var result = new List<ITypeSymbol[]>();
Expand Down Expand Up @@ -4676,23 +4386,6 @@ private static void ProcessTypeForGenerics(ITypeSymbol paramType, ITypeSymbol ac
}
}

private static bool ValidateTypeConstraints(INamedTypeSymbol classType, ITypeSymbol[] typeArguments)
{
// Validate constraints for a generic class
if (!classType.IsGenericType)
{
return true;
}

var typeParams = classType.TypeParameters;
if (typeParams.Length != typeArguments.Length)
{
return false;
}

return ValidateTypeParameterConstraints(typeParams, typeArguments);
}

private static bool ValidateTypeConstraints(IMethodSymbol method, ITypeSymbol[] typeArguments)
{
// Only validate method type parameters here - class type parameters are validated separately
Expand Down
Loading
Loading