diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
new file mode 100644
index 0000000000..712c765f98
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
@@ -0,0 +1,24 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(2)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value, string? expression = default)
+ where TKey : notnull
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
new file mode 100644
index 0000000000..712c765f98
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
@@ -0,0 +1,24 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(2)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value, string? expression = default)
+ where TKey : notnull
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
new file mode 100644
index 0000000000..712c765f98
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_dictionary_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
@@ -0,0 +1,24 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(2)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value, string? expression = default)
+ where TKey : notnull
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
new file mode 100644
index 0000000000..25318db419
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet10_0.verified.txt
@@ -0,0 +1,23 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(3)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value, string? expression = default)
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
new file mode 100644
index 0000000000..25318db419
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet8_0.verified.txt
@@ -0,0 +1,23 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(3)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value, string? expression = default)
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
new file mode 100644
index 0000000000..25318db419
--- /dev/null
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.Assert_That_set_specialization_emits_matching_Should_overload.DotNet9_0.verified.txt
@@ -0,0 +1,23 @@
+//
+#nullable enable
+
+using System;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions;
+using TUnit.Assertions.Should.Core;
+
+namespace TUnit.Assertions.Should;
+
+public static partial class ShouldExtensions
+{
+
+ [global::System.Runtime.CompilerServices.OverloadResolutionPriority(3)]
+ public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value, string? expression = default)
+ {
+ var source = global::TUnit.Assertions.Assert.That(value, expression);
+ var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource>)source).Context;
+ innerContext.ExpressionBuilder.Clear();
+ innerContext.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ return new global::TUnit.Assertions.Should.Core.ShouldSource>(innerContext);
+ }
+}
diff --git a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs
index 693db3a3c7..959664a4ca 100644
--- a/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs
+++ b/TUnit.Assertions.Should.SourceGenerator.Tests/ShouldExtensionGeneratorTests.cs
@@ -240,6 +240,84 @@ public static MyBetweenAssertion IsBetween(
await Assert.That(output).Contains("CallerArgumentExpression(\"max\")");
}
+ [Test]
+ public async Task Assert_That_dictionary_specialization_emits_matching_Should_overload()
+ {
+ var output = await RunGenerator("""
+ using System.Collections.Generic;
+ using System.Runtime.CompilerServices;
+ using TUnit.Assertions.Core;
+
+ namespace TUnit.Assertions;
+
+ public class DictionaryAssertion : IAssertionSource>
+ where TKey : notnull
+ {
+ public AssertionContext> Context { get; }
+ public DictionaryAssertion(IReadOnlyDictionary? value, string? expression)
+ => Context = new AssertionContext>(value!, new System.Text.StringBuilder());
+ public TypeOfAssertion, TExpected> IsTypeOf() => throw new System.NotImplementedException();
+ public IsNotTypeOfAssertion, TExpected> IsNotTypeOf() => throw new System.NotImplementedException();
+ public IsAssignableToAssertion> IsAssignableTo() => throw new System.NotImplementedException();
+ public IsNotAssignableToAssertion> IsNotAssignableTo() => throw new System.NotImplementedException();
+ public IsAssignableFromAssertion> IsAssignableFrom() => throw new System.NotImplementedException();
+ public IsNotAssignableFromAssertion> IsNotAssignableFrom() => throw new System.NotImplementedException();
+ }
+
+ public static class Assert
+ {
+ [OverloadResolutionPriority(2)]
+ public static DictionaryAssertion That(
+ IReadOnlyDictionary? value,
+ [CallerArgumentExpression(nameof(value))] string? expression = null)
+ where TKey : notnull
+ => new(value, expression);
+ }
+ """);
+
+ await Assert.That(output).Contains("public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.IReadOnlyDictionary? value");
+ await Assert.That(output).Contains("Assert.That(value, expression)");
+ await Assert.That(output).Contains("Append(expression ?? \"?\").Append(\".Should()\")");
+ }
+
+ [Test]
+ public async Task Assert_That_set_specialization_emits_matching_Should_overload()
+ {
+ var output = await RunGenerator("""
+ using System.Collections.Generic;
+ using System.Runtime.CompilerServices;
+ using TUnit.Assertions.Core;
+
+ namespace TUnit.Assertions;
+
+ public class SetAssertion : IAssertionSource>
+ {
+ public AssertionContext> Context { get; }
+ public SetAssertion(ISet? value, string? expression)
+ => Context = new AssertionContext>(value!, new System.Text.StringBuilder());
+ public TypeOfAssertion, TExpected> IsTypeOf() => throw new System.NotImplementedException();
+ public IsNotTypeOfAssertion, TExpected> IsNotTypeOf() => throw new System.NotImplementedException();
+ public IsAssignableToAssertion> IsAssignableTo() => throw new System.NotImplementedException();
+ public IsNotAssignableToAssertion> IsNotAssignableTo() => throw new System.NotImplementedException();
+ public IsAssignableFromAssertion> IsAssignableFrom() => throw new System.NotImplementedException();
+ public IsNotAssignableFromAssertion> IsNotAssignableFrom() => throw new System.NotImplementedException();
+ }
+
+ public static class Assert
+ {
+ [OverloadResolutionPriority(3)]
+ public static SetAssertion That(
+ ISet? value,
+ [CallerArgumentExpression(nameof(value))] string? expression = null)
+ => new(value, expression);
+ }
+ """);
+
+ await Assert.That(output).Contains("public static global::TUnit.Assertions.Should.Core.ShouldSource> Should(this System.Collections.Generic.ISet? value");
+ await Assert.That(output).Contains("OverloadResolutionPriority(3)");
+ await Assert.That(output).Contains("Assert.That(value, expression)");
+ }
+
///
/// Compiles with the Should-generator's input dependencies,
/// runs , snapshots the full generated source via
@@ -251,9 +329,26 @@ public static MyBetweenAssertion IsBetween(
///
private static async Task RunGenerator(string userSource, [CallerMemberName] string testName = "")
{
+ // On net8.0 hosts, OverloadResolutionPriorityAttribute is missing from the BCL and the
+ // Polyfill copy compiled into this test assembly is internal — invisible to the synthetic
+ // GeneratorTest compilation. Inject a public copy so the attribute resolves consistently
+ // across all TFMs the test multi-targets.
+ var inputTrees = new List { CSharpSyntaxTree.ParseText(userSource) };
+#if NET8_0
+ inputTrees.Add(CSharpSyntaxTree.ParseText("""
+ namespace System.Runtime.CompilerServices;
+ [System.AttributeUsage(System.AttributeTargets.Method | System.AttributeTargets.Constructor | System.AttributeTargets.Property, Inherited = false)]
+ public sealed class OverloadResolutionPriorityAttribute : System.Attribute
+ {
+ public OverloadResolutionPriorityAttribute(int priority) => Priority = priority;
+ public int Priority { get; }
+ }
+ """));
+#endif
+
var compilation = CSharpCompilation.Create(
assemblyName: "GeneratorTest",
- syntaxTrees: [CSharpSyntaxTree.ParseText(userSource)],
+ syntaxTrees: inputTrees,
references: GetReferences(),
options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
@@ -264,11 +359,11 @@ private static async Task RunGenerator(string userSource, [CallerMemberN
await Assert.That(diagnostics.Length).IsEqualTo(0)
.Because("Generator should not emit diagnostics for valid input");
- var trees = updatedCompilation.SyntaxTrees
- .Where(t => t != compilation.SyntaxTrees[0])
+ var generatedTrees = updatedCompilation.SyntaxTrees
+ .Where(t => !compilation.SyntaxTrees.Contains(t))
.Select(t => t.ToString());
- var combined = string.Join("\n//------\n", trees);
+ var combined = string.Join("\n//------\n", generatedTrees);
await Verify(combined)
.UseFileName($"{nameof(ShouldExtensionGeneratorTests)}.{testName}")
diff --git a/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs b/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs
index ad54707268..8b06374577 100644
--- a/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs
+++ b/TUnit.Assertions.Should.SourceGenerator/ShouldExtensionGenerator.cs
@@ -30,6 +30,8 @@ public sealed class ShouldExtensionGenerator : IIncrementalGenerator
private const string AssertionSourceFullName = "TUnit.Assertions.Core.IAssertionSource`1";
private const string AssertionBaseFullName = "TUnit.Assertions.Core.Assertion`1";
private const string AssertionContextFullName = "TUnit.Assertions.Core.AssertionContext`1";
+ private const string AssertFullName = "TUnit.Assertions.Assert";
+ private const string CallerArgumentExpressionAttributeFullName = "System.Runtime.CompilerServices.CallerArgumentExpressionAttribute";
private const string ShouldExtensionsNamespace = "TUnit.Assertions.Should.Extensions";
private const string ShouldNameAttributeFullName = "TUnit.Assertions.Should.Attributes.ShouldNameAttribute";
private const string CallerArgumentExpressionAttributeName = "CallerArgumentExpressionAttribute";
@@ -63,6 +65,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
var emittedHints = new HashSet(StringComparer.Ordinal);
+ if (payload.Entries.Length > 0)
+ {
+ EmitShouldEntries(ctx, payload.Entries.ToArray(), emittedHints);
+ }
+
// Wrappers first: they own the return types they cover, and their method names
// win over extension methods at call sites anyway.
foreach (var wrapper in payload.Wrappers)
@@ -101,7 +108,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
private sealed record ReferenceData(
EquatableArray Methods,
EquatableArray Wrappers,
- EquatableArray AlreadyBakedNames);
+ EquatableArray Entries,
+ EquatableArray AlreadyBakedNames,
+ EquatableArray BakedShouldEntryKeys);
private static GeneratorPayload Collect(Compilation compilation)
{
@@ -110,12 +119,14 @@ private static GeneratorPayload Collect(Compilation compilation)
var assertionContext = compilation.GetTypeByMetadataName(AssertionContextFullName);
var shouldNameAttr = compilation.GetTypeByMetadataName(ShouldNameAttributeFullName);
var partialMarker = compilation.GetTypeByMetadataName(ShouldGeneratePartialAttributeFullName);
+ var assertType = compilation.GetTypeByMetadataName(AssertFullName);
if (assertionSource is null || assertionBase is null || assertionContext is null)
{
return new GeneratorPayload(
new EquatableArray(Array.Empty()),
- new EquatableArray(Array.Empty()));
+ new EquatableArray(Array.Empty()),
+ new EquatableArray(Array.Empty()));
}
// Phase 1 — per-reference scan, cached by MetadataReference identity.
@@ -131,7 +142,7 @@ private static GeneratorPayload Collect(Compilation compilation)
continue;
}
refResults.Add(GetOrComputeReferenceData(
- compilation, refAssembly, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker));
+ compilation, refAssembly, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker, assertType));
}
// Phase 2 — union dedup sets across all references.
@@ -162,6 +173,24 @@ private static GeneratorPayload Collect(Compilation compilation)
WalkForWrappers(compilation.Assembly.GlobalNamespace, partialMarker, assertionBase, assertionContext, localWrappers, isCurrentAssembly: true);
}
+ var localEntries = ImmutableArray.CreateBuilder();
+ if (assertType is not null && SymbolEqualityComparer.Default.Equals(assertType.ContainingAssembly, compilation.Assembly))
+ {
+ CollectShouldEntries(assertType, assertionSource, assertionBase, assertionContext, localEntries);
+ }
+
+ var currentShouldEntryKeys = new HashSet(StringComparer.Ordinal);
+ CollectExistingShouldEntryKeys(compilation.Assembly.GlobalNamespace, currentShouldEntryKeys);
+
+ var bakedReferencedShouldEntryKeys = new HashSet(StringComparer.Ordinal);
+ foreach (var r in refResults)
+ {
+ foreach (var bakedKey in r.BakedShouldEntryKeys)
+ {
+ bakedReferencedShouldEntryKeys.Add(bakedKey);
+ }
+ }
+
// Phase 4 — merge and apply post-walk dedup. Wrapper instance methods and Should-flavored
// extensions co-exist by design: the wrapper's [ShouldGeneratePartial] only emits methods
// whose source overload exactly matches a public ctor on the inner assertion (the simple-
@@ -186,9 +215,31 @@ private static GeneratorPayload Collect(Compilation compilation)
}
}
+ var allEntries = ImmutableArray.CreateBuilder();
+ foreach (var entry in localEntries)
+ {
+ if (!currentShouldEntryKeys.Contains(entry.SignatureKey))
+ {
+ allEntries.Add(entry);
+ }
+ }
+ foreach (var r in refResults)
+ {
+ foreach (var entry in r.Entries)
+ {
+ var signatureKey = entry.SignatureKey;
+ if (!currentShouldEntryKeys.Contains(signatureKey)
+ && !bakedReferencedShouldEntryKeys.Contains(signatureKey))
+ {
+ allEntries.Add(entry);
+ }
+ }
+ }
+
return new GeneratorPayload(
- new EquatableArray(allMethods.ToArray()),
- new EquatableArray(localWrappers.ToArray()));
+ new EquatableArray(DeduplicateMethods(allMethods.ToArray())),
+ new EquatableArray(localWrappers.ToArray()),
+ new EquatableArray(DeduplicateEntries(allEntries.ToArray())));
}
///
@@ -204,12 +255,13 @@ private static ReferenceData GetOrComputeReferenceData(
INamedTypeSymbol assertionBase,
INamedTypeSymbol assertionContext,
INamedTypeSymbol? shouldNameAttr,
- INamedTypeSymbol? partialMarker)
+ INamedTypeSymbol? partialMarker,
+ INamedTypeSymbol? assertType)
{
var metadataRef = compilation.GetMetadataReference(refAssembly);
if (metadataRef is null)
{
- return ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker);
+ return ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker, assertType);
}
if (s_referenceCache.TryGetValue(metadataRef, out var cached))
@@ -217,7 +269,7 @@ private static ReferenceData GetOrComputeReferenceData(
return cached;
}
- var fresh = ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker);
+ var fresh = ScanReference(refAssembly, compilation, assertionSource, assertionBase, assertionContext, shouldNameAttr, partialMarker, assertType);
// Concurrent races between two compilations seeing the same uncached MetadataReference
// are harmless — both compute the same ReferenceData; first writer wins. Catching
@@ -240,7 +292,8 @@ private static ReferenceData ScanReference(
INamedTypeSymbol assertionBase,
INamedTypeSymbol assertionContext,
INamedTypeSymbol? shouldNameAttr,
- INamedTypeSymbol? partialMarker)
+ INamedTypeSymbol? partialMarker,
+ INamedTypeSymbol? assertType)
{
var methods = ImmutableArray.CreateBuilder();
var ctx = new CollectionContext(
@@ -259,6 +312,12 @@ private static ReferenceData ScanReference(
WalkForWrappers(refAssembly.GlobalNamespace, partialMarker, assertionBase, assertionContext, wrappers, isCurrentAssembly: false);
}
+ var entries = ImmutableArray.CreateBuilder();
+ if (assertType is not null && SymbolEqualityComparer.Default.Equals(assertType.ContainingAssembly, refAssembly))
+ {
+ CollectShouldEntries(assertType, assertionSource, assertionBase, assertionContext, entries);
+ }
+
var bakedNames = new List();
var bakedNs = LookupNamespace(refAssembly.GlobalNamespace, ShouldExtensionsNamespace);
if (bakedNs is not null)
@@ -269,10 +328,15 @@ private static ReferenceData ScanReference(
}
}
+ var bakedShouldEntryKeys = new HashSet(StringComparer.Ordinal);
+ CollectExistingShouldEntryKeys(refAssembly.GlobalNamespace, bakedShouldEntryKeys);
+
return new ReferenceData(
- new EquatableArray(methods.ToArray()),
+ new EquatableArray(DeduplicateMethods(methods.ToArray())),
new EquatableArray(wrappers.ToArray()),
- new EquatableArray(bakedNames.ToArray()));
+ new EquatableArray(entries.ToArray()),
+ new EquatableArray(bakedNames.ToArray()),
+ new EquatableArray(bakedShouldEntryKeys.ToArray()));
}
private static void WalkForWrappers(
@@ -338,11 +402,17 @@ private static void CollectWrapper(
// The IsCurrentAssembly flag on WrapperData controls whether the emission step actually
// generates partial methods (only true for wrappers in this compilation).
var methods = ImmutableArray.CreateBuilder();
+ var existingWrapperMethodKeys = isCurrentAssembly
+ ? CollectExistingWrapperMethodKeys(type)
+ : null;
foreach (var sourceMember in EnumerateInstanceMethods(wrappedType))
{
if (TryDescribeWrapperMethod(sourceMember, wrappedAssertionTypeArg, assertionBase, assertionContext, out var data))
{
- methods.Add(data);
+ if (existingWrapperMethodKeys is null || existingWrapperMethodKeys.Add(GetWrapperMethodKey(data)))
+ {
+ methods.Add(data);
+ }
}
}
@@ -481,6 +551,53 @@ private static bool TryDescribeWrapperMethod(
return true;
}
+ private static HashSet CollectExistingWrapperMethodKeys(INamedTypeSymbol type)
+ {
+ var keys = new HashSet(StringComparer.Ordinal);
+ foreach (var method in type.GetMembers().OfType())
+ {
+ if (method.MethodKind != MethodKind.Ordinary || method.IsStatic)
+ {
+ continue;
+ }
+
+ keys.Add(GetWrapperMethodKey(method));
+ }
+
+ return keys;
+ }
+
+ private static string GetWrapperMethodKey(WrapperMethodData method)
+ {
+ // Type-parameter count is always 0 here: TryDescribeWrapperMethod rejects methods with
+ // method-level type parameters, so WrapperMethodData never carries them. Keep the literal
+ // in lockstep with the IMethodSymbol overload's "method.TypeParameters.Length" segment.
+ var sb = new StringBuilder(NameConjugator.Conjugate(method.SourceMethodName))
+ .Append('|')
+ .Append('0');
+
+ foreach (var parameter in method.Parameters)
+ {
+ sb.Append('|').Append(parameter.TypeName);
+ }
+
+ return sb.ToString();
+ }
+
+ private static string GetWrapperMethodKey(IMethodSymbol method)
+ {
+ var sb = new StringBuilder(method.Name)
+ .Append('|')
+ .Append(method.TypeParameters.Length);
+
+ foreach (var parameter in method.Parameters)
+ {
+ sb.Append('|').Append(parameter.Type.ToDisplayString(NoGlobalFormat));
+ }
+
+ return sb.ToString();
+ }
+
///
/// Pre-filter: returns true only when the reference (or one of its direct module references)
/// is itself. The check is one-level deep, NOT transitive — a
@@ -1106,9 +1223,499 @@ p.DynamicallyAccessedMembersAttribute is null
private static string FormatGenericArgs(EquatableArray args)
=> args.Length == 0 ? string.Empty : "<" + string.Join(", ", args) + ">";
+ private static MethodData[] DeduplicateMethods(MethodData[] methods)
+ {
+ var seen = new HashSet(StringComparer.Ordinal);
+ var result = new List(methods.Length);
+ foreach (var method in methods)
+ {
+ if (seen.Add(GetMethodKey(method)))
+ {
+ result.Add(method);
+ }
+ }
+
+ return result.ToArray();
+ }
+
+ private static string GetMethodKey(MethodData method)
+ {
+ var sb = new StringBuilder(method.ContainerName)
+ .Append('|')
+ .Append(method.MethodName)
+ .Append('|')
+ .Append(method.SourceTypeArgDisplay)
+ .Append('|')
+ .Append(method.AssertionTypeArgDisplay)
+ .Append('|')
+ .Append(method.ReturnTypeFullName)
+ .Append('|')
+ .Append(method.MethodGenericParams.Length);
+
+ foreach (var typeArg in method.ReturnTypeGenericArgs)
+ {
+ sb.Append('|').Append(typeArg);
+ }
+
+ foreach (var parameter in method.Parameters)
+ {
+ sb.Append('|')
+ .Append(parameter.TypeName)
+ .Append(':')
+ .Append(parameter.Name)
+ .Append(':')
+ .Append(parameter.CallerArgumentExpressionTarget);
+ }
+
+ return sb.ToString();
+ }
+
+ private static ShouldEntryData[] DeduplicateEntries(ShouldEntryData[] entries)
+ {
+ var seen = new HashSet(StringComparer.Ordinal);
+ var result = new List(entries.Length);
+ foreach (var entry in entries)
+ {
+ if (seen.Add(entry.SignatureKey))
+ {
+ result.Add(entry);
+ }
+ }
+
+ return result.ToArray();
+ }
+
+ private static void CollectExistingShouldEntryKeys(INamespaceSymbol ns, HashSet keys)
+ {
+ foreach (var type in ns.GetTypeMembers())
+ {
+ CollectExistingShouldEntryKeys(type, keys);
+ }
+
+ foreach (var nested in ns.GetNamespaceMembers())
+ {
+ CollectExistingShouldEntryKeys(nested, keys);
+ }
+ }
+
+ private static void CollectExistingShouldEntryKeys(INamedTypeSymbol type, HashSet keys)
+ {
+ foreach (var nested in type.GetTypeMembers())
+ {
+ CollectExistingShouldEntryKeys(nested, keys);
+ }
+
+ if (!string.Equals(type.Name, "ShouldExtensions", StringComparison.Ordinal)
+ || type.DeclaredAccessibility != Accessibility.Public
+ || !type.IsStatic)
+ {
+ return;
+ }
+
+ var containingNamespace = type.ContainingNamespace?.ToDisplayString(NoGlobalFormat) ?? string.Empty;
+ if (!string.Equals(containingNamespace, "TUnit.Assertions.Should", StringComparison.Ordinal))
+ {
+ return;
+ }
+
+ foreach (var method in type.GetMembers("Should").OfType())
+ {
+ if (!method.IsExtensionMethod || method.Parameters.Length == 0)
+ {
+ continue;
+ }
+
+ keys.Add(CreateShouldMethodSignatureKey(method));
+ }
+ }
+
+ private static string CreateShouldMethodSignatureKey(IMethodSymbol method, string? methodNameOverride = null)
+ {
+ var typeParameterOrdinals = new Dictionary(SymbolEqualityComparer.Default);
+ for (var i = 0; i < method.TypeParameters.Length; i++)
+ {
+ typeParameterOrdinals[method.TypeParameters[i]] = i;
+ }
+
+ var sb = new StringBuilder(methodNameOverride ?? method.Name).Append('|');
+ AppendTypeSignatureKey(sb, method.Parameters[0].Type, typeParameterOrdinals);
+ sb.Append('|').Append(method.TypeParameters.Length);
+
+ foreach (var parameter in method.Parameters.Skip(1))
+ {
+ sb.Append('|');
+ AppendTypeSignatureKey(sb, parameter.Type, typeParameterOrdinals);
+ }
+
+ return sb.ToString();
+ }
+
+ private static void AppendTypeSignatureKey(
+ StringBuilder sb,
+ ITypeSymbol type,
+ Dictionary typeParameterOrdinals)
+ {
+ switch (type)
+ {
+ case ITypeParameterSymbol typeParameter:
+ if (typeParameterOrdinals.TryGetValue(typeParameter, out var ordinal))
+ {
+ sb.Append('!').Append(ordinal);
+ }
+ else
+ {
+ sb.Append('!').Append(typeParameter.Ordinal);
+ }
+ return;
+
+ case IArrayTypeSymbol arrayType:
+ AppendTypeSignatureKey(sb, arrayType.ElementType, typeParameterOrdinals);
+ sb.Append('[');
+ if (arrayType.Rank > 1)
+ {
+ sb.Append(',', arrayType.Rank - 1);
+ }
+ sb.Append(']');
+ return;
+
+ case IPointerTypeSymbol pointerType:
+ AppendTypeSignatureKey(sb, pointerType.PointedAtType, typeParameterOrdinals);
+ sb.Append('*');
+ return;
+
+ case INamedTypeSymbol namedType:
+ var originalDefinition = namedType.OriginalDefinition;
+
+ if (originalDefinition.ContainingType is not null)
+ {
+ AppendTypeSignatureKey(sb, originalDefinition.ContainingType, typeParameterOrdinals);
+ sb.Append('+');
+ }
+ else if (!originalDefinition.ContainingNamespace.IsGlobalNamespace)
+ {
+ sb.Append(originalDefinition.ContainingNamespace.ToDisplayString()).Append('.');
+ }
+
+ sb.Append(originalDefinition.MetadataName);
+ if (namedType.TypeArguments.Length > 0)
+ {
+ sb.Append('<');
+ for (var i = 0; i < namedType.TypeArguments.Length; i++)
+ {
+ if (i > 0)
+ {
+ sb.Append(',');
+ }
+
+ AppendTypeSignatureKey(sb, namedType.TypeArguments[i], typeParameterOrdinals);
+ }
+ sb.Append('>');
+ }
+ return;
+
+ default:
+ sb.Append(type.ToDisplayString(NoGlobalFormat));
+ return;
+ }
+ }
+
+ private static void CollectShouldEntries(
+ INamedTypeSymbol assertType,
+ INamedTypeSymbol assertionSource,
+ INamedTypeSymbol assertionBase,
+ INamedTypeSymbol assertionContext,
+ ImmutableArray.Builder builder)
+ {
+ foreach (var member in assertType.GetMembers("That").OfType())
+ {
+ if (TryDescribeShouldEntry(member, assertionSource, out var entry))
+ {
+ builder.Add(entry);
+ }
+ }
+ }
+
+ private static bool TryDescribeShouldEntry(
+ IMethodSymbol method,
+ INamedTypeSymbol assertionSource,
+ out ShouldEntryData data)
+ {
+ data = null!;
+
+ if (method.DeclaredAccessibility != Accessibility.Public
+ || !method.IsStatic
+ || method.Parameters.Length == 0
+ || method.ReturnType is not INamedTypeSymbol returnType)
+ {
+ return false;
+ }
+
+ if (!ImplementsAssertionSource(returnType, assertionSource, out var sourceTypeArg))
+ {
+ return false;
+ }
+
+ if (!ShouldGenerateEntryForReceiver(method.Parameters[0].Type))
+ {
+ return false;
+ }
+
+ var receiver = method.Parameters[0];
+ if (TryGetCallerArgumentExpressionTarget(method.Parameters[^1]) is null)
+ {
+ return false;
+ }
+
+ var paramData = ImmutableArray.CreateBuilder();
+ for (var i = 1; i < method.Parameters.Length; i++)
+ {
+ var p = method.Parameters[i];
+ var caeTarget = TryGetCallerArgumentExpressionTarget(p);
+ paramData.Add(new ParameterData(
+ Name: p.Name,
+ TypeName: p.Type.ToDisplayString(NoGlobalFormat),
+ HasDefaultValue: p.HasExplicitDefaultValue,
+ DefaultValueLiteral: p.HasExplicitDefaultValue ? FormatDefaultValue(p.ExplicitDefaultValue, p.Type) : null,
+ CallerArgumentExpressionTarget: caeTarget));
+ }
+
+ var genericParams = ImmutableArray.CreateBuilder();
+ foreach (var tp in method.TypeParameters)
+ {
+ genericParams.Add(GenericParamData.From(tp, NoGlobalFormat));
+ }
+
+ var priority = TryGetOverloadResolutionPriority(method.GetAttributes());
+
+ data = new ShouldEntryData(
+ ReceiverTypeName: receiver.Type.ToDisplayString(NoGlobalFormat),
+ SourceTypeArgDisplay: sourceTypeArg.ToDisplayString(NoGlobalFormat),
+ MethodGenericParams: new EquatableArray(genericParams),
+ Parameters: new EquatableArray(paramData),
+ SignatureKey: CreateShouldMethodSignatureKey(method, "Should"),
+ Priority: priority,
+ RequiresUnreferencedCodeMessage: TryGetRucMessage(method.GetAttributes())
+ ?? TryGetRucMessage(returnType.GetAttributes())
+ ?? TryGetRucMessageFromConstructors(returnType),
+ SuppressedTrimWarnings: new EquatableArray(CollectSuppressedTrimWarnings(method.GetAttributes())),
+ ForwardedAttributes: new EquatableArray(CollectForwardedAttributes(method.GetAttributes())));
+ return true;
+ }
+
+ private static bool ShouldGenerateEntryForReceiver(ITypeSymbol receiverType)
+ {
+ if (receiverType is IArrayTypeSymbol)
+ {
+ return true;
+ }
+
+ if (receiverType is not INamedTypeSymbol named)
+ {
+ return false;
+ }
+
+ return IsOriginalDefinition(named, "System.Collections.Generic.IReadOnlyDictionary`2")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IDictionary`2")
+ || IsOriginalDefinition(named, "System.Collections.Generic.ISet`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IReadOnlySet`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IList`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IReadOnlyList`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.HashSet`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.Dictionary`2")
+ || IsOriginalDefinition(named, "System.Memory`1")
+ || IsOriginalDefinition(named, "System.ReadOnlyMemory`1")
+ || IsOriginalDefinition(named, "System.Collections.Generic.IAsyncEnumerable`1")
+ || IsFuncReturningEnumerable(named)
+ || IsFuncReturningTaskOfEnumerable(named);
+ }
+
+ private static bool IsOriginalDefinition(INamedTypeSymbol type, string fullMetadataName)
+ {
+ var original = type.OriginalDefinition;
+ var ns = original.ContainingNamespace?.ToDisplayString();
+ return string.Equals($"{ns}.{original.MetadataName}", fullMetadataName, StringComparison.Ordinal);
+ }
+
+ private static bool IsFuncReturningEnumerable(INamedTypeSymbol type)
+ {
+ return IsOriginalDefinition(type, "System.Func`1")
+ && type.TypeArguments.Length == 1
+ && type.TypeArguments[0] is INamedTypeSymbol returnType
+ && IsOriginalDefinition(returnType, "System.Collections.Generic.IEnumerable`1");
+ }
+
+ private static bool IsFuncReturningTaskOfEnumerable(INamedTypeSymbol type)
+ {
+ return IsOriginalDefinition(type, "System.Func`1")
+ && type.TypeArguments.Length == 1
+ && type.TypeArguments[0] is INamedTypeSymbol taskType
+ && IsOriginalDefinition(taskType, "System.Threading.Tasks.Task`1")
+ && taskType.TypeArguments.Length == 1
+ && taskType.TypeArguments[0] is INamedTypeSymbol enumerableType
+ && IsOriginalDefinition(enumerableType, "System.Collections.Generic.IEnumerable`1");
+ }
+
+ private static bool ImplementsAssertionSource(
+ INamedTypeSymbol type,
+ INamedTypeSymbol assertionSource,
+ out ITypeSymbol sourceTypeArg)
+ {
+ if (IsAssertionSourceInterface(type, assertionSource))
+ {
+ sourceTypeArg = type.TypeArguments[0];
+ return true;
+ }
+
+ foreach (var iface in type.AllInterfaces)
+ {
+ if (IsAssertionSourceInterface(iface, assertionSource))
+ {
+ sourceTypeArg = iface.TypeArguments[0];
+ return true;
+ }
+ }
+
+ sourceTypeArg = null!;
+ return false;
+ }
+
+ private static int TryGetOverloadResolutionPriority(ImmutableArray attrs)
+ {
+ foreach (var attr in attrs)
+ {
+ if (attr.AttributeClass?.Name == "OverloadResolutionPriorityAttribute"
+ && attr.ConstructorArguments.Length == 1
+ && attr.ConstructorArguments[0].Value is int priority)
+ {
+ return priority;
+ }
+ }
+
+ return 0;
+ }
+
+ private static void EmitShouldEntries(SourceProductionContext ctx, ShouldEntryData[] entries, HashSet emittedHints)
+ {
+ var sb = new StringBuilder();
+ sb.AppendLine("// ");
+ sb.AppendLine("#nullable enable");
+ sb.AppendLine();
+ sb.AppendLine("using System;");
+ sb.AppendLine("using System.Runtime.CompilerServices;");
+ sb.AppendLine("using TUnit.Assertions;");
+ sb.AppendLine("using TUnit.Assertions.Should.Core;");
+ sb.AppendLine();
+ sb.AppendLine("namespace TUnit.Assertions.Should;");
+ sb.AppendLine();
+ sb.AppendLine("public static partial class ShouldExtensions");
+ sb.AppendLine("{");
+
+ foreach (var entry in entries)
+ {
+ EmitShouldEntry(sb, entry);
+ }
+
+ sb.AppendLine("}");
+
+ var hint = "ShouldExtensions.Generated.g.cs";
+ var suffix = 0;
+ while (!emittedHints.Add(hint))
+ {
+ hint = $"ShouldExtensions.Generated_{++suffix}.g.cs";
+ }
+
+ ctx.AddSource(hint, sb.ToString());
+ }
+
+ private static void EmitShouldEntry(StringBuilder sb, ShouldEntryData entry)
+ {
+ var genericList = entry.MethodGenericParams.Length > 0
+ ? "<" + string.Join(", ", entry.MethodGenericParams.Select(p =>
+ p.DynamicallyAccessedMembersAttribute is null
+ ? p.Name
+ : $"{p.DynamicallyAccessedMembersAttribute} {p.Name}")) + ">"
+ : string.Empty;
+
+ var constraints = string.Join(" ", entry.MethodGenericParams
+ .Select(p => p.ConstraintClause)
+ .Where(c => c is not null));
+
+ sb.AppendLine();
+ if (entry.Priority != 0)
+ {
+ sb.AppendLine($" [global::System.Runtime.CompilerServices.OverloadResolutionPriority({entry.Priority})]");
+ }
+ if (!string.IsNullOrEmpty(entry.RequiresUnreferencedCodeMessage))
+ {
+ var escaped = entry.RequiresUnreferencedCodeMessage!.Replace("\"", "\\\"");
+ sb.AppendLine($" [global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCode(\"{escaped}\")]");
+ }
+ foreach (var code in entry.SuppressedTrimWarnings)
+ {
+ sb.AppendLine($" [global::System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage(\"Trimming\", \"{code}\", Justification = \"Forwarded from source method\")]");
+ }
+ foreach (var attr in entry.ForwardedAttributes)
+ {
+ sb.AppendLine($" {attr}");
+ }
+
+ sb.Append($" public static global::TUnit.Assertions.Should.Core.ShouldSource<{entry.SourceTypeArgDisplay}> Should{genericList}(this {entry.ReceiverTypeName} value");
+ foreach (var p in entry.Parameters)
+ {
+ sb.Append($", {p.TypeName} {p.Name}");
+ if (p.HasDefaultValue)
+ {
+ sb.Append(" = ").Append(p.DefaultValueLiteral);
+ }
+ }
+ sb.Append(')');
+
+ if (!string.IsNullOrEmpty(constraints))
+ {
+ sb.AppendLine();
+ sb.Append(" ").Append(constraints);
+ }
+
+ sb.AppendLine();
+ sb.AppendLine(" {");
+ sb.Append(" var source = global::TUnit.Assertions.Assert.That").Append(genericList).Append("(value");
+ foreach (var p in entry.Parameters)
+ {
+ sb.Append(", ").Append(p.Name);
+ }
+ sb.AppendLine(");");
+ sb.AppendLine($" var innerContext = ((global::TUnit.Assertions.Core.IAssertionSource<{entry.SourceTypeArgDisplay}>)source).Context;");
+ sb.AppendLine(" innerContext.ExpressionBuilder.Clear();");
+
+ var expressionParam = entry.Parameters.LastOrDefault(p => p.CallerArgumentExpressionTarget == "value");
+ if (expressionParam is null)
+ {
+ sb.AppendLine(" innerContext.ExpressionBuilder.Append(\"?.Should()\");");
+ }
+ else
+ {
+ sb.AppendLine($" innerContext.ExpressionBuilder.Append({expressionParam.Name} ?? \"?\").Append(\".Should()\");");
+ }
+
+ sb.AppendLine($" return new global::TUnit.Assertions.Should.Core.ShouldSource<{entry.SourceTypeArgDisplay}>(innerContext);");
+ sb.AppendLine(" }");
+ }
+
private sealed record GeneratorPayload(
EquatableArray Methods,
- EquatableArray Wrappers);
+ EquatableArray Wrappers,
+ EquatableArray Entries);
+
+ private sealed record ShouldEntryData(
+ string ReceiverTypeName,
+ string SourceTypeArgDisplay,
+ EquatableArray MethodGenericParams,
+ EquatableArray Parameters,
+ string SignatureKey,
+ int Priority,
+ string? RequiresUnreferencedCodeMessage,
+ EquatableArray SuppressedTrimWarnings,
+ EquatableArray ForwardedAttributes);
private sealed record WrapperData(
string ContainingNamespace,
diff --git a/TUnit.Assertions.Should.Tests/CollectionTests.cs b/TUnit.Assertions.Should.Tests/CollectionTests.cs
index 8c1e215950..e99e28abef 100644
--- a/TUnit.Assertions.Should.Tests/CollectionTests.cs
+++ b/TUnit.Assertions.Should.Tests/CollectionTests.cs
@@ -96,6 +96,39 @@ public async Task HaveCount()
await list.Should().HaveCount(3);
}
+ [Test]
+ public async Task Dictionary_ContainKeyWithValue()
+ {
+ IReadOnlyDictionary dict = new Dictionary
+ {
+ ["one"] = 1,
+ ["two"] = 2,
+ };
+
+ await dict.Should().ContainKeyWithValue("one", 1);
+ }
+
+ [Test]
+ public async Task HashSet_BeSupersetOf()
+ {
+ var set = new HashSet { "apple", "banana", "cherry" };
+ await set.Should().BeSupersetOf(["banana"]);
+ }
+
+ [Test]
+ public async Task HashSet_HaveCount()
+ {
+ var set = new HashSet { "apple", "banana", "cherry" };
+ await set.Should().HaveCount(3);
+ }
+
+ [Test]
+ public async Task Func_collection_HaveAtLeast()
+ {
+ Func func = () => [1, 2];
+ await func.Should().HaveAtLeast(2);
+ }
+
[Test]
public async Task Contain_predicate()
{
diff --git a/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs b/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs
index 717c1e41e9..744abe41de 100644
--- a/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs
+++ b/TUnit.Assertions.Should/Core/ShouldCollectionSource.cs
@@ -1,6 +1,5 @@
using System.ComponentModel;
using System.Runtime.CompilerServices;
-using System.Text;
using TUnit.Assertions.Conditions;
using TUnit.Assertions.Core;
using TUnit.Assertions.Should.Attributes;
@@ -21,77 +20,16 @@ namespace TUnit.Assertions.Should.Core;
///
///
[ShouldGeneratePartial(typeof(CollectionAssertion<>))]
-public sealed partial class ShouldCollectionSource : IShouldSource>
+public sealed partial class ShouldCollectionSource : ShouldEnumerableSourceBase, TItem, ShouldCollectionSource>
{
- private string? _becauseMessage;
-
- public AssertionContext> Context { get; }
-
[EditorBrowsable(EditorBrowsableState.Never)]
public ShouldCollectionSource(IEnumerable? value, string? expression)
+ : base(new AssertionContext>(value, ShouldExpressionBuilder.Build(expression)))
{
- var sb = new StringBuilder((expression?.Length ?? 1) + 16);
- sb.Append(expression ?? "?").Append(".Should()");
- Context = new AssertionContext>(value, sb);
- }
-
- public ShouldCollectionSource Because(string message)
- {
- _becauseMessage = message.Trim();
- return this;
- }
-
- string? IShouldSource>.ConsumeBecauseMessage()
- => ConsumeBecauseMessage();
-
- private string? ConsumeBecauseMessage()
- {
- var message = _becauseMessage;
- _becauseMessage = null;
- return message;
- }
-
- private TAssertion ApplyBecause(TAssertion assertion)
- where TAssertion : Assertion>
- {
- var because = ConsumeBecauseMessage();
- if (because is not null)
- {
- assertion.Because(because);
- }
- return assertion;
- }
-
- // The next three methods can't be source-generated by the simple-factory rule: their target
- // assertion ctors take a separate `predicateDescription` string (e.g.
- // CollectionAllAssertion(ctx, predicate, predicateDescription)) which the source-side method
- // fills from the CAE expression value with a literal fallback — a one-method-param-to-two-
- // ctor-params shape the generator's filter rejects.
-
- public ShouldAssertion> All(
- Func predicate,
- [CallerArgumentExpression(nameof(predicate))] string? expression = null)
- {
- Context.ExpressionBuilder.Append(".All(").Append(expression).Append(')');
- var inner = ApplyBecause(new CollectionAllAssertion, TItem>(Context, predicate, expression ?? "predicate"));
- return new ShouldAssertion>(Context, inner);
- }
-
- public ShouldAssertion> Any(
- Func predicate,
- [CallerArgumentExpression(nameof(predicate))] string? expression = null)
- {
- Context.ExpressionBuilder.Append(".Any(").Append(expression).Append(')');
- var inner = ApplyBecause(new CollectionAnyAssertion, TItem>(Context, predicate, expression ?? "predicate"));
- return new ShouldAssertion>(Context, inner);
}
- public ShouldAssertion> HaveSingleItem(
- Func predicate,
- [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ internal ShouldCollectionSource(AssertionContext> context)
+ : base(context)
{
- Context.ExpressionBuilder.Append(".HaveSingleItem(").Append(expression).Append(')');
- var inner = ApplyBecause(new HasSingleItemPredicateAssertion, TItem>(Context, predicate, expression ?? "predicate"));
- return new ShouldAssertion>(Context, inner);
}
}
diff --git a/TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs b/TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs
new file mode 100644
index 0000000000..334a92d12e
--- /dev/null
+++ b/TUnit.Assertions.Should/Core/ShouldDelegateCollectionSource.cs
@@ -0,0 +1,127 @@
+using System.ComponentModel;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions.Conditions;
+using TUnit.Assertions.Core;
+
+namespace TUnit.Assertions.Should.Core;
+
+public readonly struct ShouldDelegateCollectionSource : IShouldSource>
+{
+ private readonly string? _becauseMessage;
+
+ public AssertionContext> Context { get; }
+
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldDelegateCollectionSource(AssertionContext> context)
+ : this(context, becauseMessage: null)
+ {
+ }
+
+ private ShouldDelegateCollectionSource(AssertionContext> context, string? becauseMessage)
+ {
+ Context = context;
+ _becauseMessage = becauseMessage;
+ }
+
+ ///
+ /// Attaches a human-readable reason to the next assertion in the chain. Returns a NEW struct —
+ /// because is a readonly struct, the
+ /// result MUST be consumed inline (e.g. source.Because("...").HaveAtLeast(2)). Assigning
+ /// it to a variable and continuing on the original copy silently drops the message.
+ ///
+ public ShouldDelegateCollectionSource Because(string message)
+ => new(Context, message.Trim());
+
+ string? IShouldSource>.ConsumeBecauseMessage()
+ => _becauseMessage;
+
+ public ShouldAssertion> HaveAtLeast(
+ int minCount,
+ [CallerArgumentExpression(nameof(minCount))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append($".HaveAtLeast({expression})");
+ var inner = new CollectionHasAtLeastAssertion, TItem>(Context, minCount);
+ ApplyBecause(inner);
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> HaveAtMost(
+ int maxCount,
+ [CallerArgumentExpression(nameof(maxCount))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append($".HaveAtMost({expression})");
+ var inner = new CollectionHasAtMostAssertion, TItem>(Context, maxCount);
+ ApplyBecause(inner);
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> HaveCount(
+ int expectedCount,
+ [CallerArgumentExpression(nameof(expectedCount))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append($".HaveCount({expression})");
+ var inner = new CollectionCountAssertion, TItem>(Context, expectedCount);
+ ApplyBecause(inner);
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion Throw() where TException : Exception
+ {
+ Context.ExpressionBuilder.Append($".Throw<{DelegateExceptionTypeFormatter.FormatTypeName(typeof(TException))}>()");
+ var mapped = Context.MapException();
+ var inner = new ThrowsAssertion(mapped);
+ ApplyBecause(inner);
+ return new ShouldAssertion(mapped, inner);
+ }
+
+ public ShouldAssertion ThrowExactly() where TException : Exception
+ {
+ Context.ExpressionBuilder.Append($".ThrowExactly<{DelegateExceptionTypeFormatter.FormatTypeName(typeof(TException))}>()");
+ var mapped = Context.MapException();
+ var inner = new ThrowsExactlyAssertion(mapped);
+ ApplyBecause(inner);
+ return new ShouldAssertion(mapped, inner);
+ }
+
+ private void ApplyBecause(Assertion assertion)
+ {
+ if (_becauseMessage is not null)
+ {
+ assertion.Because(_becauseMessage);
+ }
+ }
+
+ internal static AssertionContext> CreateContext(Func?> func, string? expression)
+ {
+ var expressionBuilder = ShouldExpressionBuilder.Build(expression);
+ var evaluationContext = new EvaluationContext>(() =>
+ {
+ try
+ {
+ return Task.FromResult<(IEnumerable?, Exception?)>((func(), null));
+ }
+ catch (Exception ex)
+ {
+ return Task.FromResult<(IEnumerable?, Exception?)>((default, ex));
+ }
+ });
+ return new AssertionContext>(evaluationContext, expressionBuilder);
+ }
+
+ internal static AssertionContext> CreateContext(Func?>> func, string? expression)
+ {
+ var expressionBuilder = ShouldExpressionBuilder.Build(expression);
+ var evaluationContext = new EvaluationContext>(async () =>
+ {
+ try
+ {
+ return (await func().ConfigureAwait(false), null);
+ }
+ catch (Exception ex)
+ {
+ return (default, ex);
+ }
+ });
+ return new AssertionContext>(evaluationContext, expressionBuilder);
+ }
+}
diff --git a/TUnit.Assertions.Should/Core/ShouldDelegateSource.cs b/TUnit.Assertions.Should/Core/ShouldDelegateSource.cs
index abe2e652a5..0e58efac6e 100644
--- a/TUnit.Assertions.Should/Core/ShouldDelegateSource.cs
+++ b/TUnit.Assertions.Should/Core/ShouldDelegateSource.cs
@@ -5,6 +5,36 @@
namespace TUnit.Assertions.Should.Core;
+internal static class DelegateExceptionTypeFormatter
+{
+ ///
+ /// Renders a type's display name for the assertion's expression-builder string. Strips the
+ /// backtick-arity suffix and recurses into generic arguments so that
+ /// typeof(MyException<int>) appears as MyException<Int32> in
+ /// failure messages rather than the raw MyException`1 that
+ /// returns. Note: this runs at runtime (no Roslyn) so primitive aliases come through as their
+ /// CLR names (Int32, not int); the source-generator's emit path uses Roslyn's
+ /// display format and produces int. The asymmetry is acceptable for failure messages
+ /// — exception types are rarely generic and almost never primitive — but is worth noting.
+ ///
+ internal static string FormatTypeName(System.Type t)
+ {
+ if (!t.IsGenericType)
+ {
+ return t.Name;
+ }
+
+ var name = t.Name;
+ var tickIndex = name.IndexOf('`');
+ if (tickIndex > 0)
+ {
+ name = name.Substring(0, tickIndex);
+ }
+
+ return $"{name}<{string.Join(", ", t.GenericTypeArguments.Select(FormatTypeName))}>";
+ }
+}
+
///
/// Should-flavored entry wrapper for delegates and async functions. Surfaces
/// Throw<TException> / ThrowExactly<TException> instance methods
@@ -27,6 +57,12 @@ private ShouldDelegateSource(AssertionContext context, string? becauseMessage
_becauseMessage = becauseMessage;
}
+ ///
+ /// Attaches a human-readable reason to the next assertion in the chain. Returns a NEW struct —
+ /// because is a readonly struct, the result MUST
+ /// be consumed inline (e.g. source.Because("...").Throw<E>()). Assigning it to a
+ /// variable and continuing on the original copy silently drops the message.
+ ///
public ShouldDelegateSource Because(string message)
=> new(Context, message.Trim());
@@ -38,7 +74,7 @@ public ShouldDelegateSource Because(string message)
///
public ShouldAssertion Throw() where TException : Exception
{
- Context.ExpressionBuilder.Append($".Throw<{FormatTypeName(typeof(TException))}>()");
+ Context.ExpressionBuilder.Append($".Throw<{DelegateExceptionTypeFormatter.FormatTypeName(typeof(TException))}>()");
var mapped = Context.MapException();
var inner = new ThrowsAssertion(mapped);
ApplyBecause(inner);
@@ -50,7 +86,7 @@ public ShouldAssertion Throw() where TException : Except
///
public ShouldAssertion ThrowExactly() where TException : Exception
{
- Context.ExpressionBuilder.Append($".ThrowExactly<{FormatTypeName(typeof(TException))}>()");
+ Context.ExpressionBuilder.Append($".ThrowExactly<{DelegateExceptionTypeFormatter.FormatTypeName(typeof(TException))}>()");
var mapped = Context.MapException();
var inner = new ThrowsExactlyAssertion(mapped);
ApplyBecause(inner);
@@ -64,31 +100,4 @@ private void ApplyBecause(Assertion assertion)
assertion.Because(_becauseMessage);
}
}
-
- ///
- /// Renders a type's display name for the assertion's expression-builder string. Strips the
- /// backtick-arity suffix and recurses into generic arguments so that
- /// typeof(MyException<int>) appears as MyException<Int32> in
- /// failure messages rather than the raw MyException`1 that
- /// returns. Note: this runs at runtime (no Roslyn) so primitive aliases come through as their
- /// CLR names (Int32, not int); the source-generator's emit path uses Roslyn's
- /// display format and produces int. The asymmetry is acceptable for failure messages
- /// — exception types are rarely generic and almost never primitive — but is worth noting.
- ///
- private static string FormatTypeName(System.Type t)
- {
- if (!t.IsGenericType)
- {
- return t.Name;
- }
-
- var name = t.Name;
- var tickIndex = name.IndexOf('`');
- if (tickIndex > 0)
- {
- name = name.Substring(0, tickIndex);
- }
-
- return $"{name}<{string.Join(", ", t.GenericTypeArguments.Select(FormatTypeName))}>";
- }
}
diff --git a/TUnit.Assertions.Should/Core/ShouldDictionarySource.cs b/TUnit.Assertions.Should/Core/ShouldDictionarySource.cs
new file mode 100644
index 0000000000..c058bc04b0
--- /dev/null
+++ b/TUnit.Assertions.Should/Core/ShouldDictionarySource.cs
@@ -0,0 +1,231 @@
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Runtime.CompilerServices;
+using TUnit.Assertions.Conditions;
+using TUnit.Assertions.Core;
+using TUnit.Assertions.Should.Attributes;
+using TUnit.Assertions.Sources;
+
+namespace TUnit.Assertions.Should.Core;
+
+[ShouldGeneratePartial(typeof(DictionaryAssertion<,>))]
+public sealed partial class ShouldDictionarySource
+ : ShouldEnumerableSourceBase, KeyValuePair, ShouldDictionarySource>
+ where TKey : notnull
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldDictionarySource(IReadOnlyDictionary? value, string? expression)
+ : base(new AssertionContext>(value!, ShouldExpressionBuilder.Build(expression)))
+ {
+ }
+
+ internal ShouldDictionarySource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ IEqualityComparer? comparer,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(comparer))] string? comparerExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKey({keyExpression}, {comparerExpression})");
+ var inner = ApplyBecause(new DictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey, comparer));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryDoesNotContainKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryContainsValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryDoesNotContainValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKeyWithValue(
+ TKey expectedKey,
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(expectedValue))] string? valueExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKeyWithValue({keyExpression}, {valueExpression})");
+ var inner = ApplyBecause(new DictionaryContainsKeyWithValueAssertion, TKey, TValue>(Context, expectedKey, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AllKeys(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AllKeys(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAllKeysAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AllValues(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AllValues(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAllValuesAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AnyKey(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AnyKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAnyKeyAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AnyValue(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AnyValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new DictionaryAnyValueAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+}
+
+[ShouldGeneratePartial(typeof(MutableDictionaryAssertion<,>))]
+public sealed partial class ShouldMutableDictionarySource
+ : ShouldEnumerableSourceBase, KeyValuePair, ShouldMutableDictionarySource>
+ where TKey : notnull
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldMutableDictionarySource(IDictionary? value, string? expression)
+ : base(new AssertionContext>(value!, ShouldExpressionBuilder.Build(expression)))
+ {
+ }
+
+ internal ShouldMutableDictionarySource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKey(
+ TKey expectedKey,
+ IEqualityComparer? comparer,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(comparer))] string? comparerExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKey({keyExpression}, {comparerExpression})");
+ var inner = ApplyBecause(new MutableDictionaryContainsKeyAssertion, TKey, TValue>(Context, expectedKey, comparer));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainKey(
+ TKey expectedKey,
+ [CallerArgumentExpression(nameof(expectedKey))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryDoesNotContainKeyAssertion, TKey, TValue>(Context, expectedKey));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".ContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryContainsValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> NotContainValue(
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedValue))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".NotContainValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryDoesNotContainValueAssertion, TKey, TValue>(Context, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> ContainKeyWithValue(
+ TKey expectedKey,
+ TValue expectedValue,
+ [CallerArgumentExpression(nameof(expectedKey))] string? keyExpression = null,
+ [CallerArgumentExpression(nameof(expectedValue))] string? valueExpression = null)
+ {
+ Context.ExpressionBuilder.Append($".ContainKeyWithValue({keyExpression}, {valueExpression})");
+ var inner = ApplyBecause(new MutableDictionaryContainsKeyWithValueAssertion, TKey, TValue>(Context, expectedKey, expectedValue));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AllKeys(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AllKeys(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryAllKeysAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AllValues(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AllValues(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryAllValuesAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AnyKey(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AnyKey(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryAnyKeyAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+
+ public ShouldAssertion> AnyValue(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".AnyValue(").Append(expression).Append(')');
+ var inner = ApplyBecause(new MutableDictionaryAnyValueAssertion, TKey, TValue>(Context, predicate));
+ return new ShouldAssertion>(Context, inner);
+ }
+}
diff --git a/TUnit.Assertions.Should/Core/ShouldSetSource.cs b/TUnit.Assertions.Should/Core/ShouldSetSource.cs
new file mode 100644
index 0000000000..71889dea5e
--- /dev/null
+++ b/TUnit.Assertions.Should/Core/ShouldSetSource.cs
@@ -0,0 +1,62 @@
+using System.Collections.Generic;
+using System.ComponentModel;
+using TUnit.Assertions.Adapters;
+using TUnit.Assertions.Abstractions;
+using TUnit.Assertions.Core;
+using TUnit.Assertions.Should.Attributes;
+using TUnit.Assertions.Sources;
+
+namespace TUnit.Assertions.Should.Core;
+
+[ShouldGeneratePartial(typeof(SetAssertion<>))]
+public sealed partial class ShouldSetSource : ShouldSetSourceBase, TItem, ShouldSetSource>
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldSetSource(ISet? value, string? expression)
+ : base(new AssertionContext>(value!, ShouldExpressionBuilder.Build(expression)))
+ {
+ }
+
+ internal ShouldSetSource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ protected override ISetAdapter CreateSetAdapter(ISet value) => new SetAdapter(value);
+}
+
+#if NET5_0_OR_GREATER
+[ShouldGeneratePartial(typeof(ReadOnlySetAssertion<>))]
+public sealed partial class ShouldReadOnlySetSource : ShouldSetSourceBase, TItem, ShouldReadOnlySetSource>
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldReadOnlySetSource(IReadOnlySet? value, string? expression)
+ : base(new AssertionContext>(value!, ShouldExpressionBuilder.Build(expression)))
+ {
+ }
+
+ internal ShouldReadOnlySetSource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ protected override ISetAdapter CreateSetAdapter(IReadOnlySet value) => new ReadOnlySetAdapter(value);
+}
+#endif
+
+[ShouldGeneratePartial(typeof(HashSetAssertion<>))]
+public sealed partial class ShouldHashSetSource : ShouldSetSourceBase, TItem, ShouldHashSetSource>
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public ShouldHashSetSource(HashSet? value, string? expression)
+ : base(new AssertionContext>(value!, ShouldExpressionBuilder.Build(expression)))
+ {
+ }
+
+ internal ShouldHashSetSource(AssertionContext> context)
+ : base(context)
+ {
+ }
+
+ protected override ISetAdapter CreateSetAdapter(HashSet value) => new SetAdapter(value);
+}
diff --git a/TUnit.Assertions.Should/Core/ShouldSourceBase.cs b/TUnit.Assertions.Should/Core/ShouldSourceBase.cs
new file mode 100644
index 0000000000..2725b2c45b
--- /dev/null
+++ b/TUnit.Assertions.Should/Core/ShouldSourceBase.cs
@@ -0,0 +1,177 @@
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+using TUnit.Assertions.Abstractions;
+using TUnit.Assertions.Conditions;
+using TUnit.Assertions.Core;
+
+namespace TUnit.Assertions.Should.Core;
+
+internal static class ShouldExpressionBuilder
+{
+ internal static StringBuilder Build(string? expression)
+ {
+ var sb = new StringBuilder((expression?.Length ?? 1) + 16);
+ sb.Append(expression ?? "?").Append(".Should()");
+ return sb;
+ }
+}
+
+public abstract class ShouldSourceBase : IShouldSource
+ where TSelf : ShouldSourceBase
+{
+ private string? _becauseMessage;
+
+ protected ShouldSourceBase(AssertionContext context)
+ {
+ Context = context;
+ }
+
+ public AssertionContext Context { get; }
+
+ public TSelf Because(string message)
+ {
+ _becauseMessage = message.Trim();
+ return (TSelf)this;
+ }
+
+ string? IShouldSource.ConsumeBecauseMessage()
+ => ConsumeBecauseMessage();
+
+ protected string? ConsumeBecauseMessage()
+ {
+ var message = _becauseMessage;
+ _becauseMessage = null;
+ return message;
+ }
+
+ protected void ResetShouldExpression(string? expression)
+ {
+ Context.ExpressionBuilder.Clear();
+ Context.ExpressionBuilder.Append(expression ?? "?").Append(".Should()");
+ }
+}
+
+public abstract class ShouldEnumerableSourceBase : ShouldSourceBase
+ where TCollection : IEnumerable
+ where TSelf : ShouldEnumerableSourceBase
+{
+ protected ShouldEnumerableSourceBase(AssertionContext context)
+ : base(context)
+ {
+ }
+
+ protected TAssertion ApplyBecause(TAssertion assertion)
+ where TAssertion : Assertion
+ {
+ var because = ConsumeBecauseMessage();
+ if (because is not null)
+ {
+ assertion.Because(because);
+ }
+
+ return assertion;
+ }
+
+ public ShouldAssertion All(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".All(").Append(expression).Append(')');
+ var inner = ApplyBecause(new CollectionAllAssertion(Context, predicate, expression ?? "predicate"));
+ return new ShouldAssertion(Context, inner);
+ }
+
+ public ShouldAssertion Any(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".Any(").Append(expression).Append(')');
+ var inner = ApplyBecause(new CollectionAnyAssertion(Context, predicate, expression ?? "predicate"));
+ return new ShouldAssertion(Context, inner);
+ }
+
+ public ShouldAssertion HaveSingleItem(
+ Func predicate,
+ [CallerArgumentExpression(nameof(predicate))] string? expression = null)
+ {
+ Context.ExpressionBuilder.Append(".HaveSingleItem(").Append(expression).Append(')');
+ var inner = ApplyBecause(new HasSingleItemPredicateAssertion(Context, predicate, expression ?? "predicate"));
+ return new ShouldAssertion