From bec37064829134d7b743f3b26b7be694dc63e86c Mon Sep 17 00:00:00 2001
From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com>
Date: Fri, 1 May 2026 17:43:36 +0100
Subject: [PATCH 1/4] fix(should): add specialized assertion sources
---
...ng_Should_overload.DotNet10_0.verified.txt | 24 +
...ing_Should_overload.DotNet8_0.verified.txt | 24 +
...ing_Should_overload.DotNet9_0.verified.txt | 24 +
...ng_Should_overload.DotNet10_0.verified.txt | 23 +
...ing_Should_overload.DotNet8_0.verified.txt | 23 +
...ing_Should_overload.DotNet9_0.verified.txt | 23 +
.../ShouldExtensionGeneratorTests.cs | 78 +++
.../ShouldExtensionGenerator.cs | 648 +++++++++++++++++-
.../CollectionTests.cs | 26 +
.../Core/ShouldCollectionSource.cs | 71 +-
.../Core/ShouldDelegateCollectionSource.cs | 146 ++++
.../Core/ShouldDictionarySource.cs | 246 +++++++
.../Core/ShouldSetSource.cs | 79 +++
.../Core/ShouldSourceBase.cs | 166 +++++
TUnit.Assertions.Should/ShouldExtensions.cs | 78 ++-
...Has_No_API_Changes.DotNet10_0.verified.txt | 161 ++++-
..._Has_No_API_Changes.DotNet8_0.verified.txt | 146 +++-
..._Has_No_API_Changes.DotNet9_0.verified.txt | 161 ++++-
18 files changed, 2052 insertions(+), 95 deletions(-)
create mode 100644 TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
create mode 100644 TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
create mode 100644 TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
create mode 100644 TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
create mode 100644 TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
create mode 100644 TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
create mode 100644 TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs
create mode 100644 TUnit.Assertions.Should/Core/ShouldDictionarySource.cs
create mode 100644 TUnit.Assertions.Should/Core/ShouldSetSource.cs
create mode 100644 TUnit.Assertions.Should/Core/ShouldSourceBase.cs
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
new file mode 100644
index 0000000000..712c765f98
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
@@ -0,0 +1,24 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(2)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value, string? expression = default)
+ where TKey : notnull
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
new file mode 100644
index 0000000000..712c765f98
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
@@ -0,0 +1,24 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(2)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value, string? expression = default)
+ where TKey : notnull
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
new file mode 100644
index 0000000000..712c765f98
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
@@ -0,0 +1,24 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(2)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value, string? expression = default)
+ where TKey : notnull
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
new file mode 100644
index 0000000000..25318db419
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
@@ -0,0 +1,23 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(3)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value, string? expression = default)
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
new file mode 100644
index 0000000000..25318db419
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
@@ -0,0 +1,23 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(3)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value, string? expression = default)
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
new file mode 100644
index 0000000000..25318db419
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
@@ -0,0 +1,23 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(3)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value, string? expression = default)
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs
index 693db3a3c7..90fcd8778e 100644
--- a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs
@@ -240,6 +240,84 @@ public static MyBetweenAssertion IsBetween(
await Assert.That(output).Contains("CallerArgumentExpression(\"max\")");
}
+ [Test]
+ public async Task Assert_That_dictionary_specialization_emits_matching_Should_overload()
+ {
+ var output = await RunGenerator("""
+ using System.Collections.Generic;
+ using System.Runtime.CompilerServices;
+ using TUnit.Assertions.Core;
+
+ namespace TUnit.Assertions;
+
+ public class DictionaryAssertion : IAssertionSource>
+ where TKey : notnull
+ {
+ public AssertionContext> Context { get; }
+ public DictionaryAssertion(IReadOnlyDictionary? value, string? expression)
+ => Context = new AssertionContext>(value!, new System.Text.StringBuilder());
+ public TypeOfAssertion, TExpected> IsTypeOf() => throw new System.NotImplementedException();
+ public IsNotTypeOfAssertion, TExpected> IsNotTypeOf() => throw new System.NotImplementedException();
+ public IsAssignableToAssertion> IsAssignableTo() => throw new System.NotImplementedException();
+ public IsNotAssignableToAssertion> IsNotAssignableTo() => throw new System.NotImplementedException();
+ public IsAssignableFromAssertion> IsAssignableFrom() => throw new System.NotImplementedException();
+ public IsNotAssignableFromAssertion> IsNotAssignableFrom() => throw new System.NotImplementedException();
+ }
+
+ public static class Assert
+ {
+ [OverloadResolutionPriority(2)]
+ public static DictionaryAssertion That(
+ IReadOnlyDictionary? value,
+ [CallerArgumentExpression(nameof(value))] string? expression = null)
+ where TKey : notnull
+ => new(value, expression);
+ }
+ """);
+
+ await Assert.That(output).Contains("public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value");
+ await Assert.That(output).Contains("Assert.That(value, expression)");
+ await Assert.That(output).Contains("Append(expression ?? \"?\").Append(\".Should()\")");
+ }
+
+ [Test]
+ public async Task Assert_That_set_specialization_emits_matching_Should_overload()
+ {
+ var output = await RunGenerator("""
+ using System.Collections.Generic;
+ using System.Runtime.CompilerServices;
+ using TUnit.Assertions.Core;
+
+ namespace TUnit.Assertions;
+
+ public class SetAssertion : IAssertionSource>
+ {
+ public AssertionContext> Context { get; }
+ public SetAssertion(ISet? value, string? expression)
+ => Context = new AssertionContext>(value!, new System.Text.StringBuilder());
+ public TypeOfAssertion, TExpected> IsTypeOf() => throw new System.NotImplementedException();
+ public IsNotTypeOfAssertion, TExpected> IsNotTypeOf() => throw new System.NotImplementedException();
+ public IsAssignableToAssertion> IsAssignableTo() => throw new System.NotImplementedException();
+ public IsNotAssignableToAssertion> IsNotAssignableTo() => throw new System.NotImplementedException();
+ public IsAssignableFromAssertion> IsAssignableFrom() => throw new System.NotImplementedException();
+ public IsNotAssignableFromAssertion> IsNotAssignableFrom() => throw new System.NotImplementedException();
+ }
+
+ public static class Assert
+ {
+ [OverloadResolutionPriority(3)]
+ public static SetAssertion That(
+ ISet? value,
+ [CallerArgumentExpression(nameof(value))] string? expression = null)
+ => new(value, expression);
+ }
+ """);
+
+ await Assert.That(output).Contains("public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value");
+ await Assert.That(output).Contains("OverloadResolutionPriority(3)");
+ await Assert.That(output).Contains("Assert.That(value, expression)");
+ }
+
///
/// Compiles with the Should-generator's input dependencies,
/// runs , snapshots the full generated source via
diff --git a/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs b/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs
index ad54707268..05b3ffe309 100644
--- a/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs
+++ b/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs
@@ -30,6 +30,8 @@ public sealed class ShouldExtensionGenerator : IIncrementalGenerator
private const string AssertionSourceFullName = "TUnit.Assertions.Core.IAssertionSource`1";
private const string AssertionBaseFullName = "TUnit.Assertions.Core.Assertion`1";
private const string AssertionContextFullName = "TUnit.Assertions.Core.AssertionContext`1";
+ private const string AssertFullName = "TUnit.Assertions.Assert";
+ private const string CallerArgumentExpressionAttributeFullName = "System.Runtime.CompilerServices.CallerArgumentExpressionAttribute";
private const string ShouldExtensionsNamespace = "TUnit.Assertions.Should.Extensions";
private const string ShouldNameAttributeFullName = "TUnit.Assertions.Should.Attributes.ShouldNameAttribute";
private const string CallerArgumentExpressionAttributeName = "CallerArgumentExpressionAttribute";
@@ -63,6 +65,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
var emittedHints = new HashSet(StringComparer.Ordinal);
+ if (payload.Entries.Length > 0)
+ {
+ EmitShouldEntries(ctx, payload.Entries.ToArray(), emittedHints);
+ }
+
// Wrappers first: they own the return types they cover, and their method names
// win over extension methods at call sites anyway.
foreach (var wrapper in payload.Wrappers)
@@ -101,7 +108,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
private sealed record ReferenceData(
EquatableArray Methods,
EquatableArray Wrappers,
- EquatableArray AlreadyBakedNames);
+ EquatableArray Entries,
+ EquatableArray AlreadyBakedNames,
+ EquatableArray BakedShouldEntryKeys);
private static GeneratorPayload Collect(Compilation compilation)
{
@@ -110,12 +119,14 @@ private static GeneratorPayload Collect(Compilation compilation)
var assertionContext = compilation.GetTypeByMetadataName(AssertionContextFullName);
var shouldNameAttr = compilation.GetTypeByMetadataName(ShouldNameAttributeFullName);
var partialMarker = compilation.GetTypeByMetadataName(ShouldGeneratePartialAttributeFullName);
+ var assertType = compilation.GetTypeByMetadataName(AssertFullName);
if (assertionSource is null || assertionBase is null || assertionContext is null)
{
return new GeneratorPayload(
new EquatableArray(Array.Empty()),
- new EquatableArray(Array.Empty()));
+ new EquatableArray(Array.Empty()),
+ new EquatableArray(Array.Empty()));
}
// Phase 1 — per-reference scan, cached by MetadataReference identity.
@@ -131,7 +142,7 @@ private static GeneratorPayload Collect(Compilation compilation)
continue;
}
refResults.Add(GetOrComputeReferenceData(
- compilation, refAssembly, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker));
+ compilation, refAssembly, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker, assertType));
}
// Phase 2 — union dedup sets across all references.
@@ -162,6 +173,24 @@ private static GeneratorPayload Collect(Compilation compilation)
WalkForWrappers(compilation.Assembly.GlobalNamespace, partialMarker, assertionBase, assertionContext, localWrappers, isCurrentAssembly: true);
}
+ var localEntries = ImmutableArray.CreateBuilder();
+ if (assertType is not null && SymbolEqualityComparer.Default.Equals(assertType.ContainingAssembly, compilation.Assembly))
+ {
+ CollectShouldEntries(assertType, assertionSource, assertionBase, assertionContext, localEntries);
+ }
+
+ var currentShouldEntryKeys = new HashSet(StringComparer.Ordinal);
+ CollectExistingShouldEntryKeys(compilation.Assembly.GlobalNamespace, currentShouldEntryKeys);
+
+ var bakedReferencedShouldEntryKeys = new HashSet(StringComparer.Ordinal);
+ foreach (var r in refResults)
+ {
+ foreach (var bakedKey in r.BakedShouldEntryKeys)
+ {
+ bakedReferencedShouldEntryKeys.Add(bakedKey);
+ }
+ }
+
// Phase 4 — merge and apply post-walk dedup. Wrapper instance methods and Should-flavored
// extensions co-exist by design: the wrapper's [ShouldGeneratePartial] only emits methods
// whose source overload exactly matches a public ctor on the inner assertion (the simple-
@@ -186,9 +215,31 @@ private static GeneratorPayload Collect(Compilation compilation)
}
}
+ var allEntries = ImmutableArray.CreateBuilder();
+ foreach (var entry in localEntries)
+ {
+ if (!currentShouldEntryKeys.Contains(entry.SignatureKey))
+ {
+ allEntries.Add(entry);
+ }
+ }
+ foreach (var r in refResults)
+ {
+ foreach (var entry in r.Entries)
+ {
+ var signatureKey = entry.SignatureKey;
+ if (!currentShouldEntryKeys.Contains(signatureKey)
+ && !bakedReferencedShouldEntryKeys.Contains(signatureKey))
+ {
+ allEntries.Add(entry);
+ }
+ }
+ }
+
return new GeneratorPayload(
- new EquatableArray(allMethods.ToArray()),
- new EquatableArray(localWrappers.ToArray()));
+ new EquatableArray(DeduplicateMethods(allMethods.ToArray())),
+ new EquatableArray(localWrappers.ToArray()),
+ new EquatableArray(DeduplicateEntries(allEntries.ToArray())));
}
///
@@ -204,12 +255,13 @@ private static ReferenceData GetOrComputeReferenceData(
INamedTypeSymbol assertionBase,
INamedTypeSymbol assertionContext,
INamedTypeSymbol? shouldNameAttr,
- INamedTypeSymbol? partialMarker)
+ INamedTypeSymbol? partialMarker,
+ INamedTypeSymbol? assertType)
{
var metadataRef = compilation.GetMetadataReference(refAssembly);
if (metadataRef is null)
{
- return ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker);
+ return ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker, assertType);
}
if (s_referenceCache.TryGetValue(metadataRef, out var cached))
@@ -217,7 +269,7 @@ private static ReferenceData GetOrComputeReferenceData(
return cached;
}
- var fresh = ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker);
+ var fresh = ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker, assertType);
// Concurrent races between two compilations seeing the same uncached MetadataReference
// are harmless — both compute the same ReferenceData; first writer wins. Catching
@@ -240,7 +292,8 @@ private static ReferenceData ScanReference(
INamedTypeSymbol assertionBase,
INamedTypeSymbol assertionContext,
INamedTypeSymbol? shouldNameAttr,
- INamedTypeSymbol? partialMarker)
+ INamedTypeSymbol? partialMarker,
+ INamedTypeSymbol? assertType)
{
var methods = ImmutableArray.CreateBuilder();
var ctx = new CollectionContext(
@@ -259,6 +312,12 @@ private static ReferenceData ScanReference(
WalkForWrappers(refAssembly.GlobalNamespace, partialMarker, assertionBase, assertionContext, wrappers, isCurrentAssembly: false);
}
+ var entries = ImmutableArray.CreateBuilder();
+ if (assertType is not null && SymbolEqualityComparer.Default.Equals(assertType.ContainingAssembly, refAssembly))
+ {
+ CollectShouldEntries(assertType, assertionSource, assertionBase, assertionContext, entries);
+ }
+
var bakedNames = new List();
var bakedNs = LookupNamespace(refAssembly.GlobalNamespace, ShouldExtensionsNamespace);
if (bakedNs is not null)
@@ -269,10 +328,15 @@ private static ReferenceData ScanReference(
}
}
+ var bakedShouldEntryKeys = new HashSet(StringComparer.Ordinal);
+ CollectExistingShouldEntryKeys(refAssembly.GlobalNamespace, bakedShouldEntryKeys);
+
return new ReferenceData(
- new EquatableArray(methods.ToArray()),
+ new EquatableArray(DeduplicateMethods(methods.ToArray())),
new EquatableArray(wrappers.ToArray()),
- new EquatableArray(bakedNames.ToArray()));
+ new EquatableArray(entries.ToArray()),
+ new EquatableArray(bakedNames.ToArray()),
+ new EquatableArray(bakedShouldEntryKeys.ToArray()));
}
private static void WalkForWrappers(
@@ -338,11 +402,17 @@ private static void CollectWrapper(
// The IsCurrentAssembly flag on WrapperData controls whether the emission step actually
// generates partial methods (only true for wrappers in this compilation).
var methods = ImmutableArray.CreateBuilder();
+ var existingWrapperMethodKeys = isCurrentAssembly
+ ? CollectExistingWrapperMethodKeys(type)
+ : null;
foreach (var sourceMember in EnumerateInstanceMethods(wrappedType))
{
if (TryDescribeWrapperMethod(sourceMember, wrappedAssertionTypeArg, assertionBase, assertionContext, out var data))
{
- methods.Add(data);
+ if (existingWrapperMethodKeys is null || existingWrapperMethodKeys.Add(GetWrapperMethodKey(data)))
+ {
+ methods.Add(data);
+ }
}
}
@@ -481,6 +551,50 @@ private static bool TryDescribeWrapperMethod(
return true;
}
+ private static HashSet CollectExistingWrapperMethodKeys(INamedTypeSymbol type)
+ {
+ var keys = new HashSet(StringComparer.Ordinal);
+ foreach (var method in type.GetMembers().OfType())
+ {
+ if (method.MethodKind != MethodKind.Ordinary || method.IsStatic)
+ {
+ continue;
+ }
+
+ keys.Add(GetWrapperMethodKey(method));
+ }
+
+ return keys;
+ }
+
+ private static string GetWrapperMethodKey(WrapperMethodData method)
+ {
+ var sb = new StringBuilder(NameConjugator.Conjugate(method.SourceMethodName))
+ .Append('|')
+ .Append('0');
+
+ foreach (var parameter in method.Parameters)
+ {
+ sb.Append('|').Append(parameter.TypeName);
+ }
+
+ return sb.ToString();
+ }
+
+ private static string GetWrapperMethodKey(IMethodSymbol method)
+ {
+ var sb = new StringBuilder(method.Name)
+ .Append('|')
+ .Append(method.TypeParameters.Length);
+
+ foreach (var parameter in method.Parameters)
+ {
+ sb.Append('|').Append(parameter.Type.ToDisplayString(NoGlobalFormat));
+ }
+
+ return sb.ToString();
+ }
+
///
/// Pre-filter: returns true only when the reference (or one of its direct module references)
/// is itself. The check is one-level deep, NOT transitive — a
@@ -1106,9 +1220,517 @@ p.DynamicallyAccessedMembersAttribute is null
private static string FormatGenericArgs(EquatableArray args)
=> args.Length == 0 ? string.Empty : "<" + string.Join(", ", args) + ">";
+ private static MethodData[] DeduplicateMethods(MethodData[] methods)
+ {
+ var seen = new HashSet(StringComparer.Ordinal);
+ var result = new List(methods.Length);
+ foreach (var method in methods)
+ {
+ if (seen.Add(GetMethodKey(method)))
+ {
+ result.Add(method);
+ }
+ }
+
+ return result.ToArray();
+ }
+
+ private static string GetMethodKey(MethodData method)
+ {
+ var sb = new StringBuilder(method.ContainerName)
+ .Append('|')
+ .Append(method.MethodName)
+ .Append('|')
+ .Append(method.SourceTypeArgDisplay)
+ .Append('|')
+ .Append(method.AssertionTypeArgDisplay)
+ .Append('|')
+ .Append(method.ReturnTypeFullName)
+ .Append('|')
+ .Append(method.MethodGenericParams.Length);
+
+ foreach (var typeArg in method.ReturnTypeGenericArgs)
+ {
+ sb.Append('|').Append(typeArg);
+ }
+
+ foreach (var parameter in method.Parameters)
+ {
+ sb.Append('|')
+ .Append(parameter.TypeName)
+ .Append(':')
+ .Append(parameter.Name)
+ .Append(':')
+ .Append(parameter.CallerArgumentExpressionTarget);
+ }
+
+ return sb.ToString();
+ }
+
+ private static ShouldEntryData[] DeduplicateEntries(ShouldEntryData[] entries)
+ {
+ var seen = new HashSet(StringComparer.Ordinal);
+ var result = new List(entries.Length);
+ foreach (var entry in entries)
+ {
+ if (seen.Add(GetShouldEntryKey(entry)))
+ {
+ result.Add(entry);
+ }
+ }
+
+ return result.ToArray();
+ }
+
+ private static string GetShouldEntryKey(ShouldEntryData entry)
+ {
+ var sb = new StringBuilder(entry.ReceiverTypeName)
+ .Append('|')
+ .Append(entry.ReceiverTypeName)
+ .Append('|')
+ .Append(entry.Priority)
+ .Append('|')
+ .Append(entry.MethodGenericParams.Length);
+
+ foreach (var p in entry.Parameters)
+ {
+ sb.Append('|').Append(p.TypeName).Append(':').Append(p.Name);
+ }
+
+ return sb.ToString();
+ }
+
+ private static void CollectExistingShouldEntryKeys(INamespaceSymbol ns, HashSet keys)
+ {
+ foreach (var type in ns.GetTypeMembers())
+ {
+ CollectExistingShouldEntryKeys(type, keys);
+ }
+
+ foreach (var nested in ns.GetNamespaceMembers())
+ {
+ CollectExistingShouldEntryKeys(nested, keys);
+ }
+ }
+
+ private static void CollectExistingShouldEntryKeys(INamedTypeSymbol type, HashSet keys)
+ {
+ foreach (var nested in type.GetTypeMembers())
+ {
+ CollectExistingShouldEntryKeys(nested, keys);
+ }
+
+ if (!string.Equals(type.Name, "ShouldExtensions", StringComparison.Ordinal)
+ || type.DeclaredAccessibility != Accessibility.Public
+ || !type.IsStatic)
+ {
+ return;
+ }
+
+ var containingNamespace = type.ContainingNamespace?.ToDisplayString(NoGlobalFormat) ?? string.Empty;
+ if (!string.Equals(containingNamespace, "TUnit.Assertions.Should", StringComparison.Ordinal))
+ {
+ return;
+ }
+
+ foreach (var method in type.GetMembers("Should").OfType())
+ {
+ if (!method.IsExtensionMethod || method.Parameters.Length == 0)
+ {
+ continue;
+ }
+
+ keys.Add(CreateShouldMethodSignatureKey(method));
+ }
+ }
+
+ private static string CreateShouldMethodSignatureKey(IMethodSymbol method, string? methodNameOverride = null)
+ {
+ var typeParameterOrdinals = new Dictionary(SymbolEqualityComparer.Default);
+ for (var i = 0; i < method.TypeParameters.Length; i++)
+ {
+ typeParameterOrdinals[method.TypeParameters[i]] = i;
+ }
+
+ var sb = new StringBuilder(methodNameOverride ?? method.Name).Append('|');
+ AppendTypeSignatureKey(sb, method.Parameters[0].Type, typeParameterOrdinals);
+ sb.Append('|').Append(method.TypeParameters.Length);
+
+ foreach (var parameter in method.Parameters.Skip(1))
+ {
+ sb.Append('|');
+ AppendTypeSignatureKey(sb, parameter.Type, typeParameterOrdinals);
+ }
+
+ return sb.ToString();
+ }
+
+ private static void AppendTypeSignatureKey(
+ StringBuilder sb,
+ ITypeSymbol type,
+ Dictionary typeParameterOrdinals)
+ {
+ switch (type)
+ {
+ case ITypeParameterSymbol typeParameter:
+ if (typeParameterOrdinals.TryGetValue(typeParameter, out var ordinal))
+ {
+ sb.Append('!').Append(ordinal);
+ }
+ else
+ {
+ sb.Append('!').Append(typeParameter.Ordinal);
+ }
+ return;
+
+ case IArrayTypeSymbol arrayType:
+ AppendTypeSignatureKey(sb, arrayType.ElementType, typeParameterOrdinals);
+ sb.Append('[');
+ if (arrayType.Rank > 1)
+ {
+ sb.Append(',', arrayType.Rank - 1);
+ }
+ sb.Append(']');
+ return;
+
+ case IPointerTypeSymbol pointerType:
+ AppendTypeSignatureKey(sb, pointerType.PointedAtType, typeParameterOrdinals);
+ sb.Append('*');
+ return;
+
+ case INamedTypeSymbol namedType:
+ var originalDefinition = namedType.OriginalDefinition;
+
+ if (originalDefinition.ContainingType is not null)
+ {
+ AppendTypeSignatureKey(sb, originalDefinition.ContainingType, typeParameterOrdinals);
+ sb.Append('+');
+ }
+ else if (!originalDefinition.ContainingNamespace.IsGlobalNamespace)
+ {
+ sb.Append(originalDefinition.ContainingNamespace.ToDisplayString()).Append('.');
+ }
+
+ sb.Append(originalDefinition.MetadataName);
+ if (namedType.TypeArguments.Length > 0)
+ {
+ sb.Append('<');
+ for (var i = 0; i < namedType.TypeArguments.Length; i++)
+ {
+ if (i > 0)
+ {
+ sb.Append(',');
+ }
+
+ AppendTypeSignatureKey(sb, namedType.TypeArguments[i], typeParameterOrdinals);
+ }
+ sb.Append('>');
+ }
+ return;
+
+ default:
+ sb.Append(type.ToDisplayString(NoGlobalFormat));
+ return;
+ }
+ }
+
+ private static void CollectShouldEntries(
+ INamedTypeSymbol assertType,
+ INamedTypeSymbol assertionSource,
+ INamedTypeSymbol assertionBase,
+ INamedTypeSymbol assertionContext,
+ ImmutableArray.Builder builder)
+ {
+ foreach (var member in assertType.GetMembers("That").OfType())
+ {
+ if (TryDescribeShouldEntry(member, assertionSource, out var entry))
+ {
+ builder.Add(entry);
+ }
+ }
+ }
+
+ private static bool TryDescribeShouldEntry(
+ IMethodSymbol method,
+ INamedTypeSymbol assertionSource,
+ out ShouldEntryData data)
+ {
+ data = null!;
+
+ if (method.DeclaredAccessibility != Accessibility.Public
+ || !method.IsStatic
+ || method.Parameters.Length == 0
+ || method.ReturnType is not INamedTypeSymbol returnType)
+ {
+ return false;
+ }
+
+ if (!ImplementsAssertionSource(returnType, assertionSource, out var sourceTypeArg))
+ {
+ return false;
+ }
+
+ if (!ShouldGenerateEntryForReceiver(method.Parameters[0].Type))
+ {
+ return false;
+ }
+
+ var receiver = method.Parameters[0];
+ if (TryGetCallerArgumentExpressionTarget(method.Parameters[^1]) is null)
+ {
+ return false;
+ }
+
+ var paramData = ImmutableArray.CreateBuilder();
+ for (var i = 1; i < method.Parameters.Length; i++)
+ {
+ var p = method.Parameters[i];
+ var caeTarget = TryGetCallerArgumentExpressionTarget(p);
+ paramData.Add(new ParameterData(
+ Name: p.Name,
+ TypeName: p.Type.ToDisplayString(NoGlobalFormat),
+ HasDefaultValue: p.HasExplicitDefaultValue,
+ DefaultValueLiteral: p.HasExplicitDefaultValue ? FormatDefaultValue(p.ExplicitDefaultValue, p.Type) : null,
+ CallerArgumentExpressionTarget: caeTarget));
+ }
+
+ var genericParams = ImmutableArray.CreateBuilder();
+ foreach (var tp in method.TypeParameters)
+ {
+ genericParams.Add(GenericParamData.From(tp, NoGlobalFormat));
+ }
+
+ var priority = TryGetOverloadResolutionPriority(method.GetAttributes());
+
+ data = new ShouldEntryData(
+ ReceiverTypeName: receiver.Type.ToDisplayString(NoGlobalFormat),
+ SourceTypeArgDisplay: sourceTypeArg.ToDisplayString(NoGlobalFormat),
+ MethodGenericParams: new EquatableArray(genericParams),
+ Parameters: new EquatableArray(paramData),
+ SignatureKey: CreateShouldMethodSignatureKey(method, "Should"),
+ Priority: priority,
+ RequiresUnreferencedCodeMessage: TryGetRucMessage(method.GetAttributes())
+ ?? TryGetRucMessage(returnType.GetAttributes())
+ ?? TryGetRucMessageFromConstructors(returnType),
+ SuppressedTrimWarnings: new EquatableArray(CollectSuppressedTrimWarnings(method.GetAttributes())),
+ ForwardedAttributes: new EquatableArray(CollectForwardedAttributes(method.GetAttributes())));
+ return true;
+ }
+
+ private static bool ShouldGenerateEntryForReceiver(ITypeSymbol receiverType)
+ {
+ if (receiverType is IArrayTypeSymbol)
+ {
+ return true;
+ }
+
+ if (receiverType is not INamedTypeSymbol named)
+ {
+ return false;
+ }
+
+ return IsOriginalDefinition(named, "System.Collections.Generic.IReadOnlyDictionary`2")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IDictionary`2")
+ || IsOriginalDefinition(named, "System.Collections.Generic.ISet`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IReadOnlySet`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IList`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IReadOnlyList`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.HashSet`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.Dictionary`2")
+ || IsOriginalDefinition(named, "System.Memory`1")
+ || IsOriginalDefinition(named, "System.ReadOnlyMemory`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IAsyncEnumerable`1")
+ || IsFuncReturningEnumerable(named)
+ || IsFuncReturningTaskOfEnumerable(named);
+ }
+
+ private static bool IsOriginalDefinition(INamedTypeSymbol type, string fullMetadataName)
+ {
+ var original = type.OriginalDefinition;
+ var ns = original.ContainingNamespace?.ToDisplayString();
+ return string.Equals($"{ns}.{original.MetadataName}", fullMetadataName, StringComparison.Ordinal);
+ }
+
+ private static bool IsFuncReturningEnumerable(INamedTypeSymbol type)
+ {
+ return IsOriginalDefinition(type, "System.Func`1")
+ && type.TypeArguments.Length == 1
+ && type.TypeArguments[0] is INamedTypeSymbol returnType
+ && IsOriginalDefinition(returnType, "System.Collections.Generic.IEnumerable`1");
+ }
+
+ private static bool IsFuncReturningTaskOfEnumerable(INamedTypeSymbol type)
+ {
+ return IsOriginalDefinition(type, "System.Func`1")
+ && type.TypeArguments.Length == 1
+ && type.TypeArguments[0] is INamedTypeSymbol taskType
+ && IsOriginalDefinition(taskType, "System.Threading.Tasks.Task`1")
+ && taskType.TypeArguments.Length == 1
+ && taskType.TypeArguments[0] is INamedTypeSymbol enumerableType
+ && IsOriginalDefinition(enumerableType, "System.Collections.Generic.IEnumerable`1");
+ }
+
+ private static bool ImplementsAssertionSource(
+ INamedTypeSymbol type,
+ INamedTypeSymbol assertionSource,
+ out ITypeSymbol sourceTypeArg)
+ {
+ if (IsAssertionSourceInterface(type, assertionSource))
+ {
+ sourceTypeArg = type.TypeArguments[0];
+ return true;
+ }
+
+ foreach (var iface in type.AllInterfaces)
+ {
+ if (IsAssertionSourceInterface(iface, assertionSource))
+ {
+ sourceTypeArg = iface.TypeArguments[0];
+ return true;
+ }
+ }
+
+ sourceTypeArg = null!;
+ return false;
+ }
+
+ private static int TryGetOverloadResolutionPriority(ImmutableArray attrs)
+ {
+ foreach (var attr in attrs)
+ {
+ if (attr.AttributeClass?.Name == "OverloadResolutionPriorityAttribute"
+ && attr.ConstructorArguments.Length == 1
+ && attr.ConstructorArguments[0].Value is int priority)
+ {
+ return priority;
+ }
+ }
+
+ return 0;
+ }
+
+ private static void EmitShouldEntries(SourceProductionContext ctx, ShouldEntryData[] entries, HashSet emittedHints)
+ {
+ var sb = new StringBuilder();
+ sb.AppendLine("// ");
+ sb.AppendLine("#nullable enable");
+ sb.AppendLine();
+ sb.AppendLine("using System;");
+ sb.AppendLine("using System.Runtime.CompilerServices;");
+ sb.AppendLine("using TUnit.Assertions;");
+ sb.AppendLine("using TUnit.Assertions.Should.Core;");
+ sb.AppendLine();
+ sb.AppendLine("namespace TUnit.Assertions.Should;");
+ sb.AppendLine();
+ sb.AppendLine("public static partial class ShouldExtensions");
+ sb.AppendLine("{");
+
+ foreach (var entry in entries)
+ {
+ EmitShouldEntry(sb, entry);
+ }
+
+ sb.AppendLine("}");
+
+ var hint = "ShouldExtensions.Generated.g.cs";
+ var suffix = 0;
+ while (!emittedHints.Add(hint))
+ {
+ hint = $"ShouldExtensions.Generated_{++suffix}.g.cs";
+ }
+
+ ctx.AddSource(hint, sb.ToString());
+ }
+
+ private static void EmitShouldEntry(StringBuilder sb, ShouldEntryData entry)
+ {
+ var genericList = entry.MethodGenericParams.Length > 0
+ ? "<" + string.Join(", ", entry.MethodGenericParams.Select(p =>
+ p.DynamicallyAccessedMembersAttribute is null
+ ? p.Name
+ : $"{p.DynamicallyAccessedMembersAttribute} {p.Name}")) + ">"
+ : string.Empty;
+
+ var constraints = string.Join(" ", entry.MethodGenericParams
+ .Select(p => p.ConstraintClause)
+ .Where(c => c is not null));
+
+ sb.AppendLine();
+ if (entry.Priority != 0)
+ {
+ sb.AppendLine($" [global::System.Runtime.CompilerServices.OverloadResolutionPriority({entry.Priority})]");
+ }
+ if (!string.IsNullOrEmpty(entry.RequiresUnreferencedCodeMessage))
+ {
+ var escaped = entry.RequiresUnreferencedCodeMessage!.Replace("\"", "\\\"");
+ sb.AppendLine($" [global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCode(\"{escaped}\")]");
+ }
+ foreach (var code in entry.SuppressedTrimWarnings)
+ {
+ sb.AppendLine($" [global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage(\"Trimming\", \"{code}\", Justification = \"Forwarded from source method\")]");
+ }
+ foreach (var attr in entry.ForwardedAttributes)
+ {
+ sb.AppendLine($" {attr}");
+ }
+
+ sb.Append($" public static global::TUnit.Assertions.Should.Core.ShouldSource<{entry.SourceTypeArgDisplay}> Should{genericList}(this {entry.ReceiverTypeName} value");
+ foreach (var p in entry.Parameters)
+ {
+ sb.Append($", {p.TypeName} {p.Name}");
+ if (p.HasDefaultValue)
+ {
+ sb.Append(" = ").Append(p.DefaultValueLiteral);
+ }
+ }
+ sb.Append(')');
+
+ if (!string.IsNullOrEmpty(constraints))
+ {
+ sb.AppendLine();
+ sb.Append(" ").Append(constraints);
+ }
+
+ sb.AppendLine();
+ sb.AppendLine(" {");
+ sb.Append(" var source = global::TUnit.Assertions.Assert.That").Append(genericList).Append("(value");
+ foreach (var p in entry.Parameters)
+ {
+ sb.Append(", ").Append(p.Name);
+ }
+ sb.AppendLine(");");
+ sb.AppendLine($" var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource<{entry.SourceTypeArgDisplay}>)source).Context;");
+ sb.AppendLine(" innerContext.ExpressionBuilder.Clear();");
+
+ var expressionParam = entry.Parameters.LastOrDefault(p => p.CallerArgumentExpressionTarget == "value");
+ if (expressionParam is null)
+ {
+ sb.AppendLine(" innerContext.ExpressionBuilder.Append(\"?.Should()\");");
+ }
+ else
+ {
+ sb.AppendLine($" innerContext.ExpressionBuilder.Append({expressionParam.Name} ?? \"?\").Append(\".Should()\");");
+ }
+
+ sb.AppendLine($" return new global::TUnit.Assertions.Should.Core.ShouldSource<{entry.SourceTypeArgDisplay}>(innerContext);");
+ sb.AppendLine(" }");
+ }
+
private sealed record GeneratorPayload(
EquatableArray Methods,
- EquatableArray Wrappers);
+ EquatableArray Wrappers,
+ EquatableArray Entries);
+
+ private sealed record ShouldEntryData(
+ string ReceiverTypeName,
+ string SourceTypeArgDisplay,
+ EquatableArray MethodGenericParams,
+ EquatableArray Parameters,
+ string SignatureKey,
+ int Priority,
+ string? RequiresUnreferencedCodeMessage,
+ EquatableArray SuppressedTrimWarnings,
+ EquatableArray ForwardedAttributes);
private sealed record WrapperData(
string ContainingNamespace,
diff --git a/TUnit.Assertions.Should.Tests/CollectionTests.cs b/TUnit.Assertions.Should.Tests/CollectionTests.cs
index 8c1e215950..76a5989752 100644
--- a/TUnit.Assertions.Should.Tests/CollectionTests.cs
+++ b/TUnit.Assertions.Should.Tests/CollectionTests.cs
@@ -96,6 +96,32 @@ public async Task HaveCount()
await list.Should().HaveCount(3);
}
+ [Test]
+ public async Task Dictionary_ContainKeyWithValue()
+ {
+ IReadOnlyDictionary dict = new Dictionary
+ {
+ ["one"] = 1,
+ ["two"] = 2,
+ };
+
+ await dict.Should().ContainKeyWithValue("one", 1);
+ }
+
+ [Test]
+ public async Task HashSet_BeSupersetOf()
+ {
+ var set = new HashSet { "apple", "banana", "cherry" };
+ await set.Should().BeSupersetOf(["banana"]);
+ }
+
+ [Test]
+ public async Task Func_collection_HaveAtLeast()
+ {
+ Func func = () => [1, 2];
+ await func.Should().HaveAtLeast(2);
+ }
+
[Test]
public async Task Contain_predicate()
{
diff --git a/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs b/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs
index 717c1e41e9..5c15737c25 100644
--- a/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs
+++ b/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs
@@ -21,77 +21,22 @@ namespace TUnit.Assertions.Should.Core;
///
///
[ShouldGeneratePartial(typeof(CollectionAssertion<>))]
-public sealed partial class ShouldCollectionSource : IShouldSource>
+public sealed partial class ShouldCollectionSource : ShouldEnumerableSourceBase, TItem, ShouldCollectionSource>
{
- private string? _becauseMessage;
-
- public AssertionContext> Context { get; }
-
[EditorBrowsable(EditorBrowsableState.Never)]
public ShouldCollectionSource(IEnumerable? value, string? expression)
+ : base(new AssertionContext>(value, BuildExpression(expression)))
{
- var sb = new StringBuilder((expression?.Length ?? 1) + 16);
- sb.Append(expression ?? "?").Append(".Should()");
- Context = new AssertionContext>(value, sb);
}
- public ShouldCollectionSource Because(string message)
+ internal ShouldCollectionSource(AssertionContext> context)
+ : base(context)
{
- _becauseMessage = message.Trim();
- return this;
}
-
- string? IShouldSource>.ConsumeBecauseMessage()
- => ConsumeBecauseMessage();
-
- private string? ConsumeBecauseMessage()
+ private static StringBuilder BuildExpression(string? expression)
{
- var message = _becauseMessage;
- _becauseMessage = null;
- return message;
- }
-
- private TAssertion ApplyBecause(TAssertion assertion)
- where TAssertion : Assertion>
- {
- var because = ConsumeBecauseMessage();
- if (because is not null)
- {
- assertion.Because(because);
- }
- return assertion;
- }
-
- // The next three methods can't be source-generated by the simple-factory rule: their target
- // assertion ctors take a separate `predicateDescription` string (e.g.
- // CollectionAllAssertion(ctx, predicate, predicateDescription)) which the source-side method
- // fills from the CAE expression value with a literal fallback — a one-method-param-to-two-
- // ctor-params shape the generator's filter rejects.
-
- public ShouldAssertion> All(
- Func predicate,
- [CallerArgumentExpression(nameof(predicate))] string? expression = null)
- {
- Context.ExpressionBuilder.Append(".All(").Append(expression).Append(')');
- var inner = ApplyBecause(new CollectionAllAssertion, TItem>(Context, predicate, expression ?? "predicate"));
- return new ShouldAssertion>(Context, inner);
- }
-
- public ShouldAssertion> Any(
- Func predicate,
- [CallerArgumentExpression(nameof(predicate))] string? expression = null)
- {
- Context.ExpressionBuilder.Append(".Any(").Append(expression).Append(')');
- var inner = ApplyBecause(new CollectionAnyAssertion, TItem>(Context, predicate, expression ?? "predicate"));
- return new ShouldAssertion>(Context, inner);
- }
-
- public ShouldAssertion> HaveSingleItem(
- Func predicate,
- [CallerArgumentExpression(nameof(predicate))] string? expression = null)
- {
- Context.ExpressionBuilder.Append(".HaveSingleItem(").Append(expression).Append(')');
- var inner = ApplyBecause(new HasSingleItemPredicateAssertion, TItem>(Context, predicate, expression ?? "predicate"));
- return new ShouldAssertion>(Context, inner);
+ var sb = new StringBuilder((expression?.Length ?? 1) + 16);
+ sb.Append(expression ?? "?").Append(".Should()");
+ return sb;
}
}
diff --git a/TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs b/TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs
new file mode 100644
index 0000000000..5340841daf
--- /dev/null
+++ b/TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs
@@ -0,0 +1,146 @@
+using System.ComponentModel;
+using System.Runtime.CompilerServices;
+using System.Text;
+using TUnit.Assertions.Conditions;
+using TUnit.Assertions.Core;
+
+namespace TUnit.Assertions.Should.Core;
+
+public readonly struct ShouldDelegateCollectionSource : IShouldSource>
+{
+ private readonly string? _becauseMessage;
+
+ public AssertionContext> Context { get; }
+
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldDelegateCollectionSource(AssertionContext> context)
+ : this(context, becauseMessage: null)
+ {
+ }
+
+ private ShouldDelegateCollectionSource(AssertionContext> context, string? becauseMessage)
+ {
+ Context = context;
+ _becauseMessage = becauseMessage;
+ }
+
+ public ShouldDelegateCollectionSource Because(string message)
+ => new(Context, message.Trim());
+
+ string? IShouldSource>.ConsumeBecauseMessage()
+ => _becauseMessage;
+
+ public ShouldAssertion> HaveAtLeast(
+ int minCount,
+ [CallerArgumentExpression(nameof(minCount))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append($".HaveAtLeast({expression})");
+ var inner = new CollectionHasAtLeastAssertion, TItem>(Context, minCount);
+ ApplyBecause(inner);
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> HaveAtMost(
+ int maxCount,
+ [CallerArgumentExpression(nameof(maxCount))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append($".HaveAtMost({expression})");
+ var inner = new CollectionHasAtMostAssertion, TItem>(Context, maxCount);
+ ApplyBecause(inner);
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> HaveCount(
+ int expectedCount,
+ [CallerArgumentExpression(nameof(expectedCount))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append($".HaveCount({expression})");
+ var inner = new CollectionCountAssertion, TItem>(Context, expectedCount);
+ ApplyBecause(inner);
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion Throw() where TException : Exception
+ {
+ Context.ExpressionBuilder.Append($".Throw<{FormatTypeName(typeof(TException))}>()");
+ var mapped = Context.MapException();
+ var inner = new ThrowsAssertion(mapped);
+ ApplyBecause(inner);
+ return new ShouldAssertion(mapped, inner);
+ }
+
+ public ShouldAssertion ThrowExactly() where TException : Exception
+ {
+ Context.ExpressionBuilder.Append($".ThrowExactly<{FormatTypeName(typeof(TException))}>()");
+ var mapped = Context.MapException();
+ var inner = new ThrowsExactlyAssertion(mapped);
+ ApplyBecause(inner);
+ return new ShouldAssertion(mapped, inner);
+ }
+
+ private void ApplyBecause(Assertion assertion)
+ {
+ if (_becauseMessage is not null)
+ {
+ assertion.Because(_becauseMessage);
+ }
+ }
+
+ private static string FormatTypeName(Type t)
+ {
+ if (!t.IsGenericType)
+ {
+ return t.Name;
+ }
+
+ var name = t.Name;
+ var tickIndex = name.IndexOf('`');
+ if (tickIndex > 0)
+ {
+ name = name.Substring(0, tickIndex);
+ }
+
+ return $"{name}<{string.Join(", ", t.GenericTypeArguments.Select(FormatTypeName))}>";
+ }
+
+ internal static AssertionContext> CreateContext(Func?> func, string? expression)
+ {
+ var expressionBuilder = BuildExpression(expression);
+ var evaluationContext = new EvaluationContext>(() =>
+ {
+ try
+ {
+ return Task.FromResult<(IEnumerable?, Exception?)>((func(), null));
+ }
+ catch (Exception ex)
+ {
+ return Task.FromResult<(IEnumerable?, Exception?)>((default, ex));
+ }
+ });
+ return new AssertionContext>(evaluationContext, expressionBuilder);
+ }
+
+ internal static AssertionContext> CreateContext(Func?>> func, string? expression)
+ {
+ var expressionBuilder = BuildExpression(expression);
+ var evaluationContext = new EvaluationContext>(async () =>
+ {
+ try
+ {
+ return (await func().ConfigureAwait(false), null);
+ }
+ catch (Exception ex)
+ {
+ return (default, ex);
+ }
+ });
+ return new AssertionContext>(evaluationContext, expressionBuilder);
+ }
+
+ private static StringBuilder BuildExpression(string? expression)
+ {
+ var sb = new StringBuilder((expression?.Length ?? 1) + 16);
+ sb.Append(expression ?? "?").Append(".Should()");
+ return sb;
+ }
+}
diff --git a/TUnit.Assertions.Should/Core/ShouldDictionarySource.cs b/TUnit.Assertions.Should/Core/ShouldDictionarySource.cs
new file mode 100644
index 0000000000..e372264dfa
--- /dev/null
+++ b/TUnit.Assertions.Should/Core/ShouldDictionarySource.cs
@@ -0,0 +1,246 @@
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Runtime.CompilerServices;
+using System.Text;
+using TUnit.Assertions.Conditions;
+using TUnit.Assertions.Core;
+using TUnit.Assertions.Should.Attributes;
+using TUnit.Assertions.Sources;
+
+namespace TUnit.Assertions.Should.Core;
+
+[ShouldGeneratePartial(typeof(DictionaryAssertion<,>))]
+public sealed partial class ShouldDictionarySource
+ : ShouldEnumerableSourceBase, KeyValuePair, ShouldDictionarySource>
+ where TKey : notnull
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldDictionarySource(IReadOnlyDictionary? value, string? expression)
+ : base(new AssertionContext>(value!, BuildExpression(expression)))
+ {
+ }
+
+ internal ShouldDictionarySource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ IEqualityComparer? comparer,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(comparer))] string? comparerExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKey({keyExpression}, {comparerExpression})");
+ var inner = ApplyBecause(new DictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey, comparer));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryDoesNotContainKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryContainsValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryDoesNotContainValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKeyWithValue(
+ TKey expectedKey,
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(expectedValue))] string? valueExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKeyWithValue({keyExpression}, {valueExpression})");
+ var inner = ApplyBecause(new DictionaryContainsKeyWithValueAssertion, TKey, TValue>(Context, expectedKey, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AllKeys(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AllKeys(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAllKeysAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AllValues(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AllValues(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAllValuesAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AnyKey(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AnyKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAnyKeyAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AnyValue(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AnyValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAnyValueAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ private static StringBuilder BuildExpression(string? expression)
+ {
+ var sb = new StringBuilder((expression?.Length ?? 1) + 16);
+ sb.Append(expression ?? "?").Append(".Should()");
+ return sb;
+ }
+}
+
+[ShouldGeneratePartial(typeof(MutableDictionaryAssertion<,>))]
+public sealed partial class ShouldMutableDictionarySource
+ : ShouldEnumerableSourceBase, KeyValuePair, ShouldMutableDictionarySource>
+ where TKey : notnull
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldMutableDictionarySource(IDictionary? value, string? expression)
+ : base(new AssertionContext>(value!, BuildExpression(expression)))
+ {
+ }
+
+ internal ShouldMutableDictionarySource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ IEqualityComparer? comparer,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(comparer))] string? comparerExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKey({keyExpression}, {comparerExpression})");
+ var inner = ApplyBecause(new MutableDictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey, comparer));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryDoesNotContainKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryContainsValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryDoesNotContainValueAssertion