Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,37 @@ namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers;

public static class InstanceFactoryGenerator
{
/// <summary>
/// Checks if the given type has a ClassConstructor attribute on the class/base types OR at the assembly level.
/// </summary>
public static bool HasClassConstructorAttribute(INamedTypeSymbol namedTypeSymbol)
{
var hasOnClass = namedTypeSymbol.GetAttributesIncludingBaseTypes()
.Any(a => a.AttributeClass?.GloballyQualifiedNonGeneric() == WellKnownFullyQualifiedClassNames.ClassConstructorAttribute.WithGlobalPrefix);

if (hasOnClass)
{
return true;
}

return namedTypeSymbol.ContainingAssembly.GetAttributes()
.Any(a => a.AttributeClass?.GloballyQualifiedNonGeneric() == WellKnownFullyQualifiedClassNames.ClassConstructorAttribute.WithGlobalPrefix);
}

/// <summary>
/// Generates the ClassConstructor throw-stub InstanceFactory.
/// </summary>
public static void GenerateClassConstructorStub(CodeWriter writer)
{
writer.AppendLine("InstanceFactory = (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine("// ClassConstructor attribute is present - instance creation handled at runtime");
writer.AppendLine("throw new global::System.NotSupportedException(\"Instance creation for classes with ClassConstructor attribute is handled at runtime\");");
writer.Unindent();
writer.AppendLine("},");
}

/// <summary>
/// Generates code to create an instance of a type with proper required property handling.
/// This handles required properties that don't have data sources by initializing them with defaults.
Expand Down Expand Up @@ -46,26 +77,11 @@ public static void GenerateInstanceFactory(CodeWriter writer, ITypeSymbol typeSy
{
var className = typeSymbol.GloballyQualified();

// Check if the class has a ClassConstructor attribute first, before any other checks
if (typeSymbol is INamedTypeSymbol namedTypeSymbol)
// Check if the class has a ClassConstructor attribute first (class, base types, or assembly level)
if (typeSymbol is INamedTypeSymbol namedTypeSymbol && HasClassConstructorAttribute(namedTypeSymbol))
{
var hasClassConstructor = namedTypeSymbol.GetAttributesIncludingBaseTypes()
.Any(a => a.AttributeClass?.GloballyQualifiedNonGeneric() == WellKnownFullyQualifiedClassNames.ClassConstructorAttribute.WithGlobalPrefix);

if (hasClassConstructor)
{
// If class has ClassConstructor attribute, generate a factory that throws
// The actual instance creation will be handled by ClassConstructorHelper at runtime
// This applies to both generic and non-generic classes
writer.AppendLine("InstanceFactory = (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine("// ClassConstructor attribute is present - instance creation handled at runtime");
writer.AppendLine("throw new global::System.NotSupportedException(\"Instance creation for classes with ClassConstructor attribute is handled at runtime\");");
writer.Unindent();
writer.AppendLine("},");
return;
}
GenerateClassConstructorStub(writer);
return;
}

// Check if this is a generic type definition
Expand Down
215 changes: 118 additions & 97 deletions TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2979,31 +2979,38 @@ private static void GenerateGenericTestWithConcreteTypes(
GenerateMetadataForConcreteInstantiation(writer, testMethod);

// Generate instance factory that works with generic types
writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

if (testMethod.IsGenericType)
if (InstanceFactoryGenerator.HasClassConstructorAttribute(testMethod.TypeSymbol))
{
// For generic classes, we need to use runtime type construction
var openGenericTypeName = GetOpenGenericTypeName(testMethod.TypeSymbol);
writer.AppendLine($"var genericType = typeof({openGenericTypeName});");
writer.AppendLine("if (typeArgs.Length > 0)");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine("var closedType = genericType.MakeGenericType(typeArgs);");
writer.AppendLine("return global::System.Activator.CreateInstance(closedType, args)!;");
writer.Unindent();
writer.AppendLine("}");
writer.AppendLine("throw new global::System.InvalidOperationException(\"No type arguments provided for generic class\");");
InstanceFactoryGenerator.GenerateClassConstructorStub(writer);
}
else
{
writer.AppendLine($"return new {className}();");
}
writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

writer.Unindent();
writer.AppendLine("},");
if (testMethod.IsGenericType)
{
// For generic classes, we need to use runtime type construction
var openGenericTypeName = GetOpenGenericTypeName(testMethod.TypeSymbol);
writer.AppendLine($"var genericType = typeof({openGenericTypeName});");
writer.AppendLine("if (typeArgs.Length > 0)");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine("var closedType = genericType.MakeGenericType(typeArgs);");
writer.AppendLine("return global::System.Activator.CreateInstance(closedType, args)!;");
writer.Unindent();
writer.AppendLine("}");
writer.AppendLine("throw new global::System.InvalidOperationException(\"No type arguments provided for generic class\");");
}
else
{
writer.AppendLine($"return new {className}();");
}

writer.Unindent();
writer.AppendLine("},");
}

// Generate concrete instantiations dictionary
writer.AppendLine("ConcreteInstantiations = new global::System.Collections.Generic.Dictionary<string, global::TUnit.Core.TestMetadata>");
Expand Down Expand Up @@ -4490,53 +4497,60 @@ private static void GenerateConcreteTestMetadata(
GenerateConcreteMetadataWithFilteredDataSources(writer, testMethod, specificArgumentsAttribute, typeArguments);

// Generate instance factory
writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

// Check if the class has a constructor that requires arguments
var hasParameterizedConstructor = false;
var constructorParamCount = 0;

if (testMethod.IsGenericType)
if (InstanceFactoryGenerator.HasClassConstructorAttribute(testMethod.TypeSymbol))
{
// Find the primary constructor or first public constructor
var constructor = testMethod.TypeSymbol.Constructors
.Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public)
.OrderByDescending(c => c.Parameters.Length)
.FirstOrDefault();
InstanceFactoryGenerator.GenerateClassConstructorStub(writer);
}
else
{
writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

if (constructor is { Parameters.Length: > 0 })
// Check if the class has a constructor that requires arguments
var hasParameterizedConstructor = false;
var constructorParamCount = 0;

if (testMethod.IsGenericType)
{
hasParameterizedConstructor = true;
constructorParamCount = constructor.Parameters.Length;
// Find the primary constructor or first public constructor
var constructor = testMethod.TypeSymbol.Constructors
.Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public)
.OrderByDescending(c => c.Parameters.Length)
.FirstOrDefault();

if (constructor is { Parameters.Length: > 0 })
{
hasParameterizedConstructor = true;
constructorParamCount = constructor.Parameters.Length;
}
}
}

if (hasParameterizedConstructor)
{
// For classes with constructor parameters, use the specific constructor arguments from the Arguments attribute
if (specificArgumentsAttribute is { ConstructorArguments.Length: > 0 } &&
specificArgumentsAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array)
if (hasParameterizedConstructor)
{
var argumentValues = specificArgumentsAttribute.ConstructorArguments[0].Values;
var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg)));
// For classes with constructor parameters, use the specific constructor arguments from the Arguments attribute
if (specificArgumentsAttribute is { ConstructorArguments.Length: > 0 } &&
specificArgumentsAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array)
{
var argumentValues = specificArgumentsAttribute.ConstructorArguments[0].Values;
var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg)));

writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), new object[] {{ {constructorArgs} }})!;");
writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), new object[] {{ {constructorArgs} }})!;");
}
else
{
// Fallback to using args if no specific Arguments attribute
writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), args)!;");
}
}
else
{
// Fallback to using args if no specific Arguments attribute
writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), args)!;");
writer.AppendLine($"return new {concreteClassName}();");
}
}
else
{
writer.AppendLine($"return new {concreteClassName}();");
}

writer.Unindent();
writer.AppendLine("},");
writer.Unindent();
writer.AppendLine("},");
}

// Generate strongly-typed test invoker
writer.AppendLine("InvokeTypedTest = static (instance, args, cancellationToken) =>");
Expand Down Expand Up @@ -5082,59 +5096,66 @@ private static void GenerateConcreteTestMetadataForNonGeneric(
SourceInformationWriter.GenerateMethodInformation(writer, compilation, testMethod.TypeSymbol, testMethod.MethodSymbol, null, ',');

// Generate instance factory
writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

// Check if the class has a constructor that requires arguments
var hasParameterizedConstructor = false;
var constructorParamCount = 0;

// Find the primary constructor or first public constructor
var constructor = testMethod.TypeSymbol.Constructors
.Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public)
.OrderByDescending(c => c.Parameters.Length)
.FirstOrDefault();

if (constructor is { Parameters.Length: > 0 })
if (InstanceFactoryGenerator.HasClassConstructorAttribute(testMethod.TypeSymbol))
{
hasParameterizedConstructor = true;
constructorParamCount = constructor.Parameters.Length;
InstanceFactoryGenerator.GenerateClassConstructorStub(writer);
}

if (hasParameterizedConstructor)
else
{
// For classes with constructor parameters, check if we have Arguments attribute
var isArgumentsAttribute = classDataSourceAttribute?.AttributeClass?.Name == "ArgumentsAttribute";
writer.AppendLine("InstanceFactory = static (typeArgs, args) =>");
writer.AppendLine("{");
writer.Indent();

if (isArgumentsAttribute && classDataSourceAttribute is { ConstructorArguments.Length: > 0 } &&
classDataSourceAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array)
// Check if the class has a constructor that requires arguments
var hasParameterizedConstructor = false;
var constructorParamCount = 0;

// Find the primary constructor or first public constructor
var constructor = testMethod.TypeSymbol.Constructors
.Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public)
.OrderByDescending(c => c.Parameters.Length)
.FirstOrDefault();

if (constructor is { Parameters.Length: > 0 })
{
hasParameterizedConstructor = true;
constructorParamCount = constructor.Parameters.Length;
}

if (hasParameterizedConstructor)
{
var argumentValues = classDataSourceAttribute.ConstructorArguments[0].Values;
var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg)));
// For classes with constructor parameters, check if we have Arguments attribute
var isArgumentsAttribute = classDataSourceAttribute?.AttributeClass?.Name == "ArgumentsAttribute";

writer.AppendLine($"return new {className}({constructorArgs});");
if (isArgumentsAttribute && classDataSourceAttribute is { ConstructorArguments.Length: > 0 } &&
classDataSourceAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array)
{
var argumentValues = classDataSourceAttribute.ConstructorArguments[0].Values;
var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg)));

writer.AppendLine($"return new {className}({constructorArgs});");
}
else
{
// Use the args parameter if no specific arguments are provided
writer.AppendLine($"if (args.Length >= {constructorParamCount})");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine($"return new {className}({string.Join(", ", Enumerable.Range(0, constructorParamCount).Select(i => $"args[{i}]"))});");
writer.Unindent();
writer.AppendLine("}");
writer.AppendLine("throw new global::System.InvalidOperationException(\"Not enough arguments provided for class constructor\");");
}
}
else
{
// Use the args parameter if no specific arguments are provided
writer.AppendLine($"if (args.Length >= {constructorParamCount})");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine($"return new {className}({string.Join(", ", Enumerable.Range(0, constructorParamCount).Select(i => $"args[{i}]"))});");
writer.Unindent();
writer.AppendLine("}");
writer.AppendLine("throw new global::System.InvalidOperationException(\"Not enough arguments provided for class constructor\");");
// No constructor parameters needed
writer.AppendLine($"return new {className}();");
}
}
else
{
// No constructor parameters needed
writer.AppendLine($"return new {className}();");
}

writer.Unindent();
writer.AppendLine("},");
writer.Unindent();
writer.AppendLine("},");
}

// Generate typed invoker
GenerateTypedInvokers(writer, testMethod, className);
Expand Down
Loading
Loading