diff --git a/TUnit.Assertions.Analyzers.CodeFixers.Tests/CollectionIsEqualToCodeFixProviderTests.cs b/TUnit.Assertions.Analyzers.CodeFixers.Tests/CollectionIsEqualToCodeFixProviderTests.cs new file mode 100644 index 0000000000..7dc6a1c067 --- /dev/null +++ b/TUnit.Assertions.Analyzers.CodeFixers.Tests/CollectionIsEqualToCodeFixProviderTests.cs @@ -0,0 +1,138 @@ +using Verifier = TUnit.Assertions.Analyzers.CodeFixers.Tests.Verifiers.CSharpCodeFixVerifier< + TUnit.Assertions.Analyzers.CollectionIsEqualToAnalyzer, + TUnit.Assertions.Analyzers.CodeFixers.CollectionIsEqualToCodeFixProvider>; + +namespace TUnit.Assertions.Analyzers.CodeFixers.Tests; + +public class CollectionIsEqualToCodeFixProviderTests +{ + [Test] + public async Task Rewrites_IsEqualTo_To_IsEquivalentTo() + { + await Verifier + .VerifyCodeFixAsync( + """ + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + var a = new List { 1, 2 }; + var b = new List { 1, 2 }; + await Assert.That(a).{|#0:IsEqualTo(b)|}; + } + } + """, + Verifier.Diagnostic(Rules.CollectionIsEqualToUsesReferenceEquality) + .WithLocation(0), + """ + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + var a = new List { 1, 2 }; + var b = new List { 1, 2 }; + await Assert.That(a).IsEquivalentTo(b); + } + } + """ + ); + } + + [Test] + public async Task Fix_Preserves_Chained_Calls() + { + await Verifier.VerifyCodeFixAsync( + """ + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + var a = new List { 1, 2 }; + var b = new List { 1, 2 }; + await Assert.That(a).{|#0:IsEqualTo(b)|}.And.IsNotNull(); + } + } + """, + Verifier.Diagnostic(Rules.CollectionIsEqualToUsesReferenceEquality).WithLocation(0), + """ + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + var a = new List { 1, 2 }; + var b = new List { 1, 2 }; + await Assert.That(a).IsEquivalentTo(b).And.IsNotNull(); + } + } + """); + } + + [Test] + public async Task Fix_Works_On_Arrays() + { + await Verifier.VerifyCodeFixAsync( + """ + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + int[] a = { 1 }; + int[] b = { 1 }; + await Assert.That(a).{|#0:IsEqualTo(b)|}; + } + } + """, + Verifier.Diagnostic(Rules.CollectionIsEqualToUsesReferenceEquality).WithLocation(0), + """ + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + int[] a = { 1 }; + int[] b = { 1 }; + await Assert.That(a).IsEquivalentTo(b); + } + } + """); + } +} diff --git a/TUnit.Assertions.Analyzers.CodeFixers/CollectionIsEqualToCodeFixProvider.cs b/TUnit.Assertions.Analyzers.CodeFixers/CollectionIsEqualToCodeFixProvider.cs new file mode 100644 index 0000000000..08f29ca1ba --- /dev/null +++ b/TUnit.Assertions.Analyzers.CodeFixers/CollectionIsEqualToCodeFixProvider.cs @@ -0,0 +1,59 @@ +using System.Collections.Immutable; +using System.Composition; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace TUnit.Assertions.Analyzers.CodeFixers; + +[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(CollectionIsEqualToCodeFixProvider)), Shared] +public class CollectionIsEqualToCodeFixProvider : CodeFixProvider +{ + public sealed override ImmutableArray FixableDiagnosticIds { get; } = + ImmutableArray.Create(Rules.CollectionIsEqualToUsesReferenceEquality.Id); + + public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer; + + public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) + { + var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + if (root is null) + { + return; + } + + foreach (var diagnostic in context.Diagnostics) + { + // Analyzer reports a span covering `IsEqualTo(...)`; FindNode returns the enclosing invocation. + if (root.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true) + is not InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax identifier } }) + { + continue; + } + + context.RegisterCodeFix( + CodeAction.Create( + title: Resources.TUnitAssertions0016CodeFixTitle, + createChangedDocument: c => ReplaceAsync(context.Document, identifier, c), + equivalenceKey: nameof(Resources.TUnitAssertions0016CodeFixTitle)), + diagnostic); + } + } + + private static async Task ReplaceAsync(Document document, IdentifierNameSyntax identifier, CancellationToken cancellationToken) + { + var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + if (root is null) + { + return document; + } + + var replacement = SyntaxFactory + .IdentifierName("IsEquivalentTo") + .WithTriviaFrom(identifier); + + return document.WithSyntaxRoot(root.ReplaceNode(identifier, replacement)); + } +} diff --git a/TUnit.Assertions.Analyzers.Tests/CollectionIsEqualToAnalyzerTests.cs b/TUnit.Assertions.Analyzers.Tests/CollectionIsEqualToAnalyzerTests.cs new file mode 100644 index 0000000000..5ee2399049 --- /dev/null +++ b/TUnit.Assertions.Analyzers.Tests/CollectionIsEqualToAnalyzerTests.cs @@ -0,0 +1,226 @@ +using Verifier = TUnit.Assertions.Analyzers.Tests.Verifiers.CSharpAnalyzerVerifier; + +namespace TUnit.Assertions.Analyzers.Tests; + +public class CollectionIsEqualToAnalyzerTests +{ + [Test] + public async Task List_IsEqualTo_Raises_Info() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + var a = new List { 1, 2, 3 }; + var b = new List { 1, 2, 3 }; + await Assert.That(a).{|#0:IsEqualTo(b)|}; + } + } + """, + Verifier.Diagnostic(Rules.CollectionIsEqualToUsesReferenceEquality) + .WithLocation(0) + ); + } + + [Test] + public async Task String_Not_Flagged() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + await Assert.That("abc").IsEqualTo("abc"); + } + } + """ + ); + } + + [Test] + public async Task Int_Not_Flagged() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() => await Assert.That(1).IsEqualTo(1); + } + """ + ); + } + + [Test] + public async Task Array_IsEqualTo_Raises_Info() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + int[] a = { 1, 2 }; + int[] b = { 1, 2 }; + await Assert.That(a).{|#0:IsEqualTo(b)|}; + } + } + """, + Verifier.Diagnostic(Rules.CollectionIsEqualToUsesReferenceEquality) + .WithLocation(0) + ); + } + + [Test] + public async Task Count_IsEqualTo_Not_Flagged() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyClass + { + [Test] + public async Task Test() + { + var list = new List { 1, 2 }; + await Assert.That(list).Count().IsEqualTo(2); + } + } + """ + ); + } + + [Test] + public async Task CustomEnumerable_With_EqualsOverride_Not_Flagged() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Collections; + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyBag : IEnumerable + { + public IEnumerator GetEnumerator() => new List().GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public override bool Equals(object? obj) => obj is MyBag; + public override int GetHashCode() => 0; + } + + public class MyClass + { + [Test] + public async Task Test() + { + await Assert.That(new MyBag()).IsEqualTo(new MyBag()); + } + } + """ + ); + } + + [Test] + public async Task CustomEnumerable_With_IEquatable_Not_Flagged() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public class MyBag : IEnumerable, IEquatable + { + public IEnumerator GetEnumerator() => new List().GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public bool Equals(MyBag? other) => other is not null; + } + + public class MyClass + { + [Test] + public async Task Test() + { + await Assert.That(new MyBag()).IsEqualTo(new MyBag()); + } + } + """ + ); + } + + [Test] + public async Task Record_Collection_Not_Flagged() + { + await Verifier + .VerifyAnalyzerAsync( + """ + using System.Collections; + using System.Collections.Generic; + using System.Threading.Tasks; + using TUnit.Assertions; + using TUnit.Assertions.Extensions; + using TUnit.Core; + + public record MyRecordBag(int X) : IEnumerable + { + public IEnumerator GetEnumerator() => new List { X }.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class MyClass + { + [Test] + public async Task Test() + { + await Assert.That(new MyRecordBag(1)).IsEqualTo(new MyRecordBag(1)); + } + } + """ + ); + } +} diff --git a/TUnit.Assertions.Analyzers/AnalyzerReleases.Unshipped.md b/TUnit.Assertions.Analyzers/AnalyzerReleases.Unshipped.md index ba3010ba8b..f30b25bebb 100644 --- a/TUnit.Assertions.Analyzers/AnalyzerReleases.Unshipped.md +++ b/TUnit.Assertions.Analyzers/AnalyzerReleases.Unshipped.md @@ -4,3 +4,4 @@ Rule ID | Category | Severity | Notes --------|----------|----------|------- TUnitAssertions0014 | Usage | Warning | Prefer IsNull() over IsEqualTo(null) TUnitAssertions0015 | Usage | Warning | Prefer IsTrue()/IsFalse() over IsEqualTo(true/false) +TUnitAssertions0016 | Usage | Info | Collection IsEqualTo compares by reference - use IsEquivalentTo diff --git a/TUnit.Assertions.Analyzers/CollectionIsEqualToAnalyzer.cs b/TUnit.Assertions.Analyzers/CollectionIsEqualToAnalyzer.cs new file mode 100644 index 0000000000..f0d4a6e267 --- /dev/null +++ b/TUnit.Assertions.Analyzers/CollectionIsEqualToAnalyzer.cs @@ -0,0 +1,208 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Operations; +using Microsoft.CodeAnalysis.Text; + +namespace TUnit.Assertions.Analyzers; + +/// +/// Detects `.IsEqualTo(...)` on collection assertion sources. Because collection +/// types don't override Equals, this uses reference equality. Users almost always +/// want content equivalence via `.IsEquivalentTo(...)`. +/// +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class CollectionIsEqualToAnalyzer : ConcurrentDiagnosticAnalyzer +{ + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create(Rules.CollectionIsEqualToUsesReferenceEquality); + + public override void InitializeInternal(AnalysisContext context) + { + context.RegisterCompilationStartAction(compilationStart => + { + var ienumerable = compilationStart.Compilation.GetTypeByMetadataName("System.Collections.IEnumerable"); + if (ienumerable is null) + { + return; + } + + compilationStart.RegisterOperationAction( + ctx => AnalyzeOperation(ctx, ienumerable), + OperationKind.Invocation); + }); + } + + private static void AnalyzeOperation(OperationAnalysisContext context, INamedTypeSymbol ienumerable) + { + var invocation = (IInvocationOperation)context.Operation; + var method = invocation.TargetMethod; + + if (method.Name != "IsEqualTo") + { + return; + } + + // Only the generic EqualsAssertion-returning overload uses default (reference) equality. + // Specialized overloads (CollectionCountEqualsAssertion, DateTimeEqualsAssertion, ...) are fine. + if (method.ReturnType is not INamedTypeSymbol { Name: "EqualsAssertion" }) + { + return; + } + + var containingNamespace = method.ContainingType?.ContainingNamespace?.ToDisplayString(); + if (containingNamespace is null || !containingNamespace.StartsWith("TUnit.Assertions")) + { + return; + } + + // For extension methods, Instance.Type is null; the `this` argument is Arguments[0] + // even though method.Parameters excludes it. + var sourceParamType = method.IsExtensionMethod + ? (invocation.Arguments.Length > 0 ? invocation.Arguments[0].Value.Type : null) + : invocation.Instance?.Type ?? method.ReceiverType; + + if (sourceParamType is not INamedTypeSymbol sourceType) + { + return; + } + + var assertedType = ExtractAssertionSourceTypeArgument(sourceType); + if (assertedType is null) + { + return; + } + + if (!IsCollectionWithoutStructuralEquality(assertedType, ienumerable)) + { + return; + } + + var syntax = invocation.Syntax; + var reportLocation = syntax is InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax memberAccess, + } invocationSyntax + ? Location.Create( + syntax.SyntaxTree, + TextSpan.FromBounds(memberAccess.Name.SpanStart, invocationSyntax.Span.End)) + : syntax.GetLocation(); + + context.ReportDiagnostic( + Diagnostic.Create(Rules.CollectionIsEqualToUsesReferenceEquality, reportLocation) + ); + } + + private static ITypeSymbol? ExtractAssertionSourceTypeArgument(INamedTypeSymbol sourceType) + { + if (TryGetAssertionSourceArg(sourceType, out var arg)) + { + return arg; + } + + foreach (var iface in sourceType.AllInterfaces) + { + if (TryGetAssertionSourceArg(iface, out arg)) + { + return arg; + } + } + + return null; + } + + private static bool TryGetAssertionSourceArg(INamedTypeSymbol type, out ITypeSymbol? arg) + { + if (type.Name == "IAssertionSource" + && type.ContainingNamespace?.ToDisplayString() == "TUnit.Assertions.Core" + && type.TypeArguments.Length == 1) + { + arg = type.TypeArguments[0]; + return true; + } + + arg = null; + return false; + } + + private static bool IsCollectionWithoutStructuralEquality(ITypeSymbol type, INamedTypeSymbol ienumerable) + { + if (type.SpecialType == SpecialType.System_String) + { + return false; + } + + var unconstructedName = (type as INamedTypeSymbol)?.ConstructedFrom?.ToDisplayString(); + if (unconstructedName is "System.Memory" + or "System.ReadOnlyMemory" + or "System.Span" + or "System.ReadOnlySpan") + { + return false; + } + + if (!ImplementsIEnumerable(type, ienumerable)) + { + return false; + } + + // Records synthesize an Equals(object) override; custom collections may too. + // EqualityComparer.Default also prefers IEquatable.Equals when implemented. + // In either case IsEqualTo is semantically correct and should not be flagged. + return !OverridesObjectEquals(type) && !ImplementsIEquatableOfSelf(type); + } + + private static bool ImplementsIEnumerable(ITypeSymbol type, INamedTypeSymbol ienumerable) + { + if (SymbolEqualityComparer.Default.Equals(type, ienumerable)) + { + return true; + } + + foreach (var iface in type.AllInterfaces) + { + if (SymbolEqualityComparer.Default.Equals(iface, ienumerable)) + { + return true; + } + } + + return false; + } + + private static bool ImplementsIEquatableOfSelf(ITypeSymbol type) + { + foreach (var iface in type.AllInterfaces) + { + if (iface.Name == "IEquatable" + && iface.ContainingNamespace?.ToDisplayString() == "System" + && iface.TypeArguments.Length == 1 + && SymbolEqualityComparer.Default.Equals(iface.TypeArguments[0], type)) + { + return true; + } + } + + return false; + } + + private static bool OverridesObjectEquals(ITypeSymbol type) + { + for (var current = type; + current is not null && current.SpecialType != SpecialType.System_Object; + current = current.BaseType) + { + foreach (var member in current.GetMembers("Equals")) + { + if (member is IMethodSymbol { IsOverride: true, Parameters.Length: 1 } m + && m.Parameters[0].Type.SpecialType == SpecialType.System_Object) + { + return true; + } + } + } + + return false; + } +} diff --git a/TUnit.Assertions.Analyzers/Resources.Designer.cs b/TUnit.Assertions.Analyzers/Resources.Designer.cs index 475559350b..5eea35af39 100644 --- a/TUnit.Assertions.Analyzers/Resources.Designer.cs +++ b/TUnit.Assertions.Analyzers/Resources.Designer.cs @@ -316,5 +316,14 @@ internal static string TUnitAssertions0009Title { return ResourceManager.GetString("TUnitAssertions0009Title", resourceCulture); } } + + /// + /// Looks up a localized string similar to Replace `.IsEqualTo(...)` with `.IsEquivalentTo(...)`. + /// + internal static string TUnitAssertions0016CodeFixTitle { + get { + return ResourceManager.GetString("TUnitAssertions0016CodeFixTitle", resourceCulture); + } + } } } diff --git a/TUnit.Assertions.Analyzers/Resources.resx b/TUnit.Assertions.Analyzers/Resources.resx index a4ce78236e..7ddf2513c6 100644 --- a/TUnit.Assertions.Analyzers/Resources.resx +++ b/TUnit.Assertions.Analyzers/Resources.resx @@ -159,4 +159,16 @@ Prefer IsTrue()/IsFalse() over IsEqualTo(true/false) - \ No newline at end of file + + `.IsEqualTo(...)` on a collection uses reference equality because collection types don't override `Equals`. Use `.IsEquivalentTo(...)` to compare contents. + + + `.IsEqualTo(...)` on a collection compares by reference - use `.IsEquivalentTo(...)` to compare contents + + + Collection `.IsEqualTo(...)` compares by reference + + + Replace `.IsEqualTo(...)` with `.IsEquivalentTo(...)` + + diff --git a/TUnit.Assertions.Analyzers/Rules.cs b/TUnit.Assertions.Analyzers/Rules.cs index 0408d33810..b58ef58988 100644 --- a/TUnit.Assertions.Analyzers/Rules.cs +++ b/TUnit.Assertions.Analyzers/Rules.cs @@ -51,6 +51,9 @@ internal static class Rules public static readonly DiagnosticDescriptor PreferIsTrueOrIsFalseOverIsEqualToBool = CreateDescriptor("TUnitAssertions0015", UsageCategory, DiagnosticSeverity.Warning); + public static readonly DiagnosticDescriptor CollectionIsEqualToUsesReferenceEquality = + CreateDescriptor("TUnitAssertions0016", UsageCategory, DiagnosticSeverity.Info); + private static DiagnosticDescriptor CreateDescriptor(string diagnosticId, string category, DiagnosticSeverity severity) { return new DiagnosticDescriptor(