Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions TUnit.Engine/Building/ReflectionMetadataBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,56 @@ public static MethodMetadata CreateMethodMetadata(
Type = type,
TypeInfo = CreateTypeInfo(type),
Class = CreateClassMetadata(type),
Parameters = method.GetParameters()
.Select((p, i) => CreateParameterMetadata(p.ParameterType, p.Name ?? "unnamed", i, p))
.ToArray(),
Parameters = BuildParameterMetadata(method.GetParameters()),
GenericTypeCount = method.IsGenericMethodDefinition ? method.GetGenericArguments().Length : 0,
ReturnTypeInfo = CreateTypeInfo(method.ReturnType),
ReturnType = method.ReturnType
};
}

#if NET8_0_OR_GREATER
[RequiresUnreferencedCode("Parameter metadata creation uses reflection")]
#endif
private static ParameterMetadata[] BuildParameterMetadata(System.Reflection.ParameterInfo[] parameters)
{
if (parameters.Length == 0)
{
return [];
}

var result = new ParameterMetadata[parameters.Length];
for (var i = 0; i < parameters.Length; i++)
{
var p = parameters[i];
result[i] = CreateParameterMetadata(p.ParameterType, p.Name ?? "unnamed", i, p);
}

return result;
}

#if NET8_0_OR_GREATER
[RequiresUnreferencedCode("Parameter metadata creation uses reflection")]
#endif
private static ParameterMetadata[] BuildConstructorParameterMetadata(System.Reflection.ParameterInfo[] parameters)
{
if (parameters.Length == 0)
{
return [];
}

var result = new ParameterMetadata[parameters.Length];
for (var i = 0; i < parameters.Length; i++)
{
var p = parameters[i];
// Preserve original behaviour: ctor parameters pass p.Name through unchanged
// (CreateParameterMetadata falls back to "param{index}" when null), unlike method
// parameters which fall back to "unnamed".
result[i] = CreateParameterMetadata(p.ParameterType, p.Name, i, p);
}

return result;
}

private static TypeInfo CreateTypeInfo(Type type)
{
return new ConcreteType(type);
Expand Down Expand Up @@ -74,9 +115,9 @@ private static ClassMetadata CreateClassMetadata([DynamicallyAccessedMembers(Dyn
var constructors = type.GetConstructors(System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance);
var constructor = constructors.FirstOrDefault();

var constructorParameters = constructor?.GetParameters()
.Select((p, i) => CreateParameterMetadata(p.ParameterType, p.Name, i, p))
.ToArray() ?? [];
var constructorParameters = constructor is null
? []
: BuildConstructorParameterMetadata(constructor.GetParameters());

return new ClassMetadata
{
Expand Down
14 changes: 8 additions & 6 deletions TUnit.Engine/Discovery/ConstructorHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ internal static class ConstructorHelper
{
var constructors = testClass.GetConstructors(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);

// First, look for constructors marked with [TestConstructor]
var testConstructorMarked = constructors.Where(c => c.GetCustomAttribute<TestConstructorAttribute>() != null).ToArray();

if (testConstructorMarked.Length > 0)
// First, look for the first constructor marked with [TestConstructor] (stop early instead
// of building a filtered array just to read index 0).
foreach (var constructor in constructors)
{
return testConstructorMarked[0];
if (constructor.GetCustomAttribute<TestConstructorAttribute>() != null)
{
return constructor;
}
}

// If no [TestConstructor] attribute found, return the first instance constructor
return constructors.FirstOrDefault();
return constructors.Length > 0 ? constructors[0] : null;
}

/// <summary>
Expand Down
82 changes: 59 additions & 23 deletions TUnit.Engine/Discovery/ReflectionHookDiscoveryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,25 +128,32 @@ public static void DiscoverInstanceHooksForType(Type closedGenericType)
// Discover hooks in each type in the inheritance chain, from base to derived
foreach (var typeInChain in inheritanceChain)
{
var methods = typeInChain.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly)
.OrderBy(m =>
{
// Get the minimum order from cached hook attributes
var (beforeAttr, afterAttr, beforeEveryAttr, afterEveryAttr) = GetCachedAttributes(m);
var declaredMethods = typeInChain.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly);

// Pre-compute the sort keys once (the old OrderBy lambda allocated a List<int> and
// recomputed attributes on every comparison). Sort by minimum hook order, then by
// MetadataToken to preserve source-file order. MetadataToken is unique per method
// within a type, so the (order, token) key is total — Array.Sort needs no extra
// stability guarantee to reproduce the previous OrderBy/ThenBy result.
var sortKeys = new HookMethodSortKey[declaredMethods.Length];
for (var i = 0; i < declaredMethods.Length; i++)
{
var m = declaredMethods[i];
var (beforeAttr, afterAttr, beforeEveryAttr, afterEveryAttr) = GetCachedAttributes(m);

var orders = new List<int>();
if (beforeAttr != null) orders.Add(beforeAttr.Order);
if (afterAttr != null) orders.Add(afterAttr.Order);
if (beforeEveryAttr != null) orders.Add(beforeEveryAttr.Order);
if (afterEveryAttr != null) orders.Add(afterEveryAttr.Order);
var hasOrder = false;
var minOrder = 0;
if (beforeAttr != null) { minOrder = beforeAttr.Order; hasOrder = true; }
if (afterAttr != null) { minOrder = hasOrder ? Math.Min(minOrder, afterAttr.Order) : afterAttr.Order; hasOrder = true; }
if (beforeEveryAttr != null) { minOrder = hasOrder ? Math.Min(minOrder, beforeEveryAttr.Order) : beforeEveryAttr.Order; hasOrder = true; }
if (afterEveryAttr != null) { minOrder = hasOrder ? Math.Min(minOrder, afterEveryAttr.Order) : afterEveryAttr.Order; }

// Use Count instead of Any() to avoid double enumeration
return orders.Count > 0 ? orders.Min() : 0;
})
.ThenBy(static m => m.MetadataToken) // Then sort by MetadataToken to preserve source file order
.ToArray();
sortKeys[i] = new HookMethodSortKey(hasOrder ? minOrder : 0, m.MetadataToken);
}

foreach (var method in methods)
Array.Sort(sortKeys, declaredMethods);

foreach (var method in declaredMethods)
{
// Check for Before attributes
var beforeAttributes = method.GetCustomAttributes<BeforeAttribute>(false);
Expand Down Expand Up @@ -847,20 +854,49 @@ private static MethodMetadata CreateMethodMetadata(
Name = method.Name,
Type = type,
Class = CreateClassMetadata(type),
Parameters = method.GetParameters().Select(p => new ParameterMetadata(p.ParameterType)
{
Name = p.Name ?? string.Empty,
Type = p.ParameterType,
TypeInfo = new ConcreteType(p.ParameterType),
ReflectionInfo = p
}).ToArray(),
Parameters = BuildParameterMetadata(method.GetParameters()),
GenericTypeCount = 0,
ReturnTypeInfo = new ConcreteType(method.ReturnType),
ReturnType = method.ReturnType,
TypeInfo = new ConcreteType(type)
};
}

private readonly struct HookMethodSortKey(int order, int metadataToken) : IComparable<HookMethodSortKey>
{
private readonly int _order = order;
private readonly int _metadataToken = metadataToken;

public int CompareTo(HookMethodSortKey other)
{
var orderComparison = _order.CompareTo(other._order);
return orderComparison != 0 ? orderComparison : _metadataToken.CompareTo(other._metadataToken);
}
}

private static ParameterMetadata[] BuildParameterMetadata(ParameterInfo[] parameters)
{
if (parameters.Length == 0)
{
return [];
}

var result = new ParameterMetadata[parameters.Length];
for (var i = 0; i < parameters.Length; i++)
{
var p = parameters[i];
result[i] = new ParameterMetadata(p.ParameterType)
{
Name = p.Name ?? string.Empty,
Type = p.ParameterType,
TypeInfo = new ConcreteType(p.ParameterType),
ReflectionInfo = p
};
}

return result;
}

private static ClassMetadata CreateClassMetadata(Type type)
{
return ClassMetadata.GetOrAdd(type.FullName ?? type.Name, () => new ClassMetadata
Expand Down
135 changes: 82 additions & 53 deletions TUnit.Engine/Discovery/ReflectionTestDataCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private static bool ShouldScanAssembly(Assembly assembly)
return false;
}

if (name.EndsWith(".resources") || name.EndsWith(".XmlSerializers"))
if (name.EndsWith(".resources", StringComparison.Ordinal) || name.EndsWith(".XmlSerializers", StringComparison.Ordinal))
{
return false;
}
Expand Down Expand Up @@ -333,7 +333,7 @@ private static bool ShouldScanAssembly(Assembly assembly)
var hasTUnitReference = false;
foreach (var reference in referencedAssemblies)
{
if (reference.Name != null && (reference.Name.StartsWith("TUnit") || reference.Name == "TUnit"))
if (reference.Name != null && (reference.Name.StartsWith("TUnit", StringComparison.Ordinal) || reference.Name == "TUnit"))
{
hasTUnitReference = true;
break;
Expand All @@ -348,6 +348,45 @@ private static bool ShouldScanAssembly(Assembly assembly)
return true;
}

/// <summary>
/// Filters the non-null entries out of <see cref="ReflectionTypeLoadException.Types"/>, using a
/// pooled scratch buffer to avoid the per-failure allocations of the old LINQ Where().ToArray().
/// Shared by both the eager and streaming discovery type caches.
/// </summary>
private static Type[] FilterLoadedTypes(Type?[]? loadedTypes)
{
if (loadedTypes is null || loadedTypes.Length == 0)
{
return [];
}

var tempArray = ArrayPool<Type>.Shared.Rent(loadedTypes.Length);
try
{
var validCount = 0;
foreach (var type in loadedTypes)
{
if (type != null)
{
tempArray[validCount++] = type;
}
}

if (validCount == 0)
{
return [];
}

var result = new Type[validCount];
Array.Copy(tempArray, result, validCount);
return result;
}
finally
{
ArrayPool<Type>.Shared.Return(tempArray);
}
}

private static async Task<List<TestMetadata>> DiscoverTestsInAssembly(Assembly assembly)
{
var discoveredTests = new List<TestMetadata>(100);
Expand All @@ -360,7 +399,7 @@ private static async Task<List<TestMetadata>> DiscoverTestsInAssembly(Assembly a
}
catch (ReflectionTypeLoadException reflectionTypeLoadException)
{
return reflectionTypeLoadException.Types.Where(static x => x != null).ToArray()!;
return FilterLoadedTypes(reflectionTypeLoadException.Types);
}
catch (Exception)
{
Expand Down Expand Up @@ -475,35 +514,8 @@ private static async IAsyncEnumerable<TestMetadata> DiscoverTestsInAssemblyStrea
}
catch (ReflectionTypeLoadException rtle)
{
// Some types might fail to load, but we can still use the ones that loaded successfully
// Optimize: Manual filtering with ArrayPool for better memory efficiency
var loadedTypes = rtle.Types;
if (loadedTypes == null)
{
return [];
}

// Use ArrayPool for temporary storage to reduce allocations
var tempArray = ArrayPool<Type>.Shared.Rent(loadedTypes.Length);
try
{
var validCount = 0;
foreach (var type in loadedTypes)
{
if (type != null)
{
tempArray[validCount++] = type;
}
}

var result = new Type[validCount];
Array.Copy(tempArray, result, validCount);
return result;
}
finally
{
ArrayPool<Type>.Shared.Return(tempArray);
}
// Some types might fail to load, but we can still use the ones that loaded successfully.
return FilterLoadedTypes(rtle.Types);
}
catch (Exception)
{
Expand Down Expand Up @@ -1659,39 +1671,56 @@ private static bool IsCovariantCompatible(Type paramType, [DynamicallyAccessedMe

var paramGenericDef = paramType.GetGenericTypeDefinition();

// List of known covariant interfaces
var covariantInterfaces = new[]
{
typeof(IEnumerable<>),
typeof(IReadOnlyList<>),
typeof(IReadOnlyCollection<>),
typeof(IEnumerator<>)
};

if (!covariantInterfaces.Contains(paramGenericDef))
if (Array.IndexOf(CovariantInterfaces, paramGenericDef) < 0)
{
return false;
}

// Check the argument type's interfaces, plus the argument type itself when it's an interface,
// without allocating a combined array (the old Concat([argType]).ToArray() per call).
var argInterfaces = AssemblyReferenceCache.GetInterfaces(argType);
if (argType.IsInterface)
foreach (var iface in argInterfaces)
{
argInterfaces = argInterfaces.Concat([argType]).ToArray();
if (MatchesCovariantInterface(paramType, paramGenericDef, iface, out var compatible))
{
return compatible;
}
}

foreach (var iface in argInterfaces)
if (argType.IsInterface && MatchesCovariantInterface(paramType, paramGenericDef, argType, out var argCompatible))
{
if (iface.IsGenericType && iface.GetGenericTypeDefinition() == paramGenericDef)
{
var paramElementType = paramType.GetGenericArguments()[0];
var argElementType = iface.GetGenericArguments()[0];
return argCompatible;
}

// For covariance to work, the parameter element type must be assignable from the argument element type
// This allows IEnumerable<int> to be passed where IEnumerable<object> is expected
return paramElementType.IsAssignableFrom(argElementType);
}
return false;
}

/// <summary>
/// Known covariant generic interface definitions. Hoisted to avoid allocating the array per
/// <see cref="IsCovariantCompatible"/> call.
/// </summary>
private static readonly Type[] CovariantInterfaces =
[
typeof(IEnumerable<>),
typeof(IReadOnlyList<>),
typeof(IReadOnlyCollection<>),
typeof(IEnumerator<>)
];

private static bool MatchesCovariantInterface(Type paramType, Type paramGenericDef, Type iface, out bool compatible)
{
if (iface.IsGenericType && iface.GetGenericTypeDefinition() == paramGenericDef)
{
var paramElementType = paramType.GetGenericArguments()[0];
var argElementType = iface.GetGenericArguments()[0];

// For covariance to work, the parameter element type must be assignable from the argument element type
// This allows IEnumerable<int> to be passed where IEnumerable<object> is expected
compatible = paramElementType.IsAssignableFrom(argElementType);
return true;
}

compatible = false;
return false;
}

Expand Down
Loading
Loading