diff --git a/README.md b/README.md
index dedaafbd..e98755ee 100755
--- a/README.md
+++ b/README.md
@@ -200,6 +200,7 @@ If you are already using other analyzers, you can check [which rules are duplica
|[MA0183](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0183.md)|Usage|string.Format should use a format string with placeholders|⚠️|✔️|❌|
|[MA0184](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0184.md)|Style|Do not use interpolated string without parameters|👻|✔️|✔️|
|[MA0185](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0185.md)|Performance|Simplify string.Create when all parameters are culture invariant|ℹ️|✔️|✔️|
+|[MA0186](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0186.md)|Design|Equals method should use \[NotNullWhen(true)\] on the parameter|ℹ️|❌|❌|
diff --git a/docs/README.md b/docs/README.md
index 5aa8d0df..d572e88a 100755
--- a/docs/README.md
+++ b/docs/README.md
@@ -184,6 +184,7 @@
|[MA0183](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0183.md)|Usage|string.Format should use a format string with placeholders|⚠️|✔️|❌|
|[MA0184](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0184.md)|Style|Do not use interpolated string without parameters|👻|✔️|✔️|
|[MA0185](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0185.md)|Performance|Simplify string.Create when all parameters are culture invariant|ℹ️|✔️|✔️|
+|[MA0186](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0186.md)|Design|Equals method should use \[NotNullWhen(true)\] on the parameter|ℹ️|❌|❌|
|Id|Suppressed rule|Justification|
|--|---------------|-------------|
@@ -751,6 +752,9 @@ dotnet_diagnostic.MA0184.severity = silent
# MA0185: Simplify string.Create when all parameters are culture invariant
dotnet_diagnostic.MA0185.severity = suggestion
+
+# MA0186: Equals method should use [NotNullWhen(true)] on the parameter
+dotnet_diagnostic.MA0186.severity = none
```
# .editorconfig - all rules disabled
@@ -1304,4 +1308,7 @@ dotnet_diagnostic.MA0184.severity = none
# MA0185: Simplify string.Create when all parameters are culture invariant
dotnet_diagnostic.MA0185.severity = none
+
+# MA0186: Equals method should use [NotNullWhen(true)] on the parameter
+dotnet_diagnostic.MA0186.severity = none
```
diff --git a/docs/Rules/MA0186.md b/docs/Rules/MA0186.md
new file mode 100644
index 00000000..6a1de676
--- /dev/null
+++ b/docs/Rules/MA0186.md
@@ -0,0 +1,118 @@
+# MA0186 - Equals method should use \[NotNullWhen(true)\] on the parameter
+
+Source: [MissingNotNullWhenAttributeOnEqualsAnalyzer.cs](https://github.com/meziantou/Meziantou.Analyzer/blob/main/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs)
+
+
+## Description
+
+This rule reports missing nullable attributes on method parameters to improve nullable reference type analysis:
+
+1. When implementing `Equals(object?)` or `IEquatable.Equals(T?)`, the parameter should be annotated with `[NotNullWhen(true)]` to inform the compiler that when the method returns `true`, the parameter is guaranteed to be non-null.
+
+2. When implementing `IDictionary.TryGetValue`, the `value` parameter should be annotated with `[MaybeNullWhen(false)]` to inform the compiler that when the method returns `false`, the parameter may be null.
+
+## Motivation
+
+### Equals methods
+
+Without the `[NotNullWhen(true)]` attribute, the compiler cannot track that a successful equality comparison means the parameter is non-null. This can lead to unnecessary null checks or null-forgiving operators in code that follows an equality check.
+
+````c#
+// ❌ Without the attribute
+public override bool Equals(object? obj)
+{
+ return obj is MyType other && /* comparison */;
+}
+
+// Usage
+if (myInstance.Equals(someObj))
+{
+ // Compiler doesn't know someObj is not null here
+ someObj.ToString(); // Warning CS8602
+}
+
+// ✅ With the attribute
+public override bool Equals([NotNullWhen(true)] object? obj)
+{
+ return obj is MyType other && /* comparison */;
+}
+
+// Usage
+if (myInstance.Equals(someObj))
+{
+ // Compiler knows someObj is not null here
+ someObj.ToString(); // No warning
+}
+````
+
+### TryGetValue methods
+
+Without the `[MaybeNullWhen(false)]` attribute, the compiler cannot track that when `TryGetValue` returns `false`, the out parameter may be null.
+
+````c#
+// ❌ Without the attribute
+public bool TryGetValue(TKey key, out TValue? value)
+{
+ // Implementation
+}
+
+// ✅ With the attribute
+public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue? value)
+{
+ // Implementation
+}
+
+// Usage
+if (dictionary.TryGetValue(key, out var value))
+{
+ value.ToString(); // Compiler knows value is not null
+}
+````
+
+## How to fix violations
+
+### For Equals methods
+
+Add the `[NotNullWhen(true)]` attribute to the parameter:
+
+````c#
+using System.Diagnostics.CodeAnalysis;
+
+public class MyType : IEquatable
+{
+ public override bool Equals([NotNullWhen(true)] object? obj)
+ {
+ return Equals(obj as MyType);
+ }
+
+ public bool Equals([NotNullWhen(true)] MyType? other)
+ {
+ return other is not null && /* comparison logic */;
+ }
+}
+````
+
+### For TryGetValue methods
+
+Add the `[MaybeNullWhen(false)]` attribute to the value parameter:
+
+````c#
+using System.Diagnostics.CodeAnalysis;
+using System.Collections.Generic;
+
+public class MyDictionary : IDictionary
+{
+ public bool TryGetValue(string key, [MaybeNullWhen(false)] out string? value)
+ {
+ // Implementation
+ }
+}
+````
+
+## Configuration
+
+This rule is disabled by default as it's a suggestion for improved nullable annotations. To enable it, add the following to your `.editorconfig` file:
+
+````editorconfig
+dotnet_diagnostic.MA0186.severity = suggestion
+````
diff --git a/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig b/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig
index 23e24986..4b5016fe 100644
--- a/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig
+++ b/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig
@@ -550,3 +550,6 @@ dotnet_diagnostic.MA0184.severity = silent
# MA0185: Simplify string.Create when all parameters are culture invariant
dotnet_diagnostic.MA0185.severity = suggestion
+
+# MA0186: Equals method should use [NotNullWhen(true)] on the parameter
+dotnet_diagnostic.MA0186.severity = none
diff --git a/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig b/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig
index bb651f75..09a71638 100644
--- a/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig
+++ b/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig
@@ -550,3 +550,6 @@ dotnet_diagnostic.MA0184.severity = none
# MA0185: Simplify string.Create when all parameters are culture invariant
dotnet_diagnostic.MA0185.severity = none
+
+# MA0186: Equals method should use [NotNullWhen(true)] on the parameter
+dotnet_diagnostic.MA0186.severity = none
diff --git a/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs b/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs
index e92414c4..2ecb89ef 100755
--- a/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs
+++ b/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs
@@ -48,6 +48,14 @@ public static bool Implements(this ITypeSymbol classSymbol, ITypeSymbol? interfa
return classSymbol.AllInterfaces.Any(interfaceType.IsEqualTo);
}
+ public static bool ImplementsGenericInterface(this ITypeSymbol classSymbol, ITypeSymbol? interfaceType)
+ {
+ if (interfaceType is null)
+ return false;
+
+ return classSymbol.AllInterfaces.Any(iface => iface.OriginalDefinition.IsEqualTo(interfaceType.OriginalDefinition));
+ }
+
public static bool IsOrImplements(this ITypeSymbol symbol, ITypeSymbol? interfaceType)
{
if (interfaceType is null)
diff --git a/src/Meziantou.Analyzer/RuleIdentifiers.cs b/src/Meziantou.Analyzer/RuleIdentifiers.cs
index 22d72ced..314aa7f8 100755
--- a/src/Meziantou.Analyzer/RuleIdentifiers.cs
+++ b/src/Meziantou.Analyzer/RuleIdentifiers.cs
@@ -185,6 +185,7 @@ internal static class RuleIdentifiers
public const string StringFormatShouldBeConstant = "MA0183";
public const string DoNotUseInterpolatedStringWithoutParameters = "MA0184";
public const string SimplifyStringCreateWhenAllParametersAreCultureInvariant = "MA0185";
+ public const string MissingNotNullWhenAttributeOnEquals = "MA0186";
public static string GetHelpUri(string identifier)
{
diff --git a/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs b/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs
new file mode 100644
index 00000000..e1949b47
--- /dev/null
+++ b/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs
@@ -0,0 +1,194 @@
+using System.Collections.Immutable;
+using System.Runtime.CompilerServices;
+using Meziantou.Analyzer.Internals;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Diagnostics;
+
+namespace Meziantou.Analyzer.Rules;
+
+[DiagnosticAnalyzer(LanguageNames.CSharp)]
+public sealed class MissingNotNullWhenAttributeOnEqualsAnalyzer : DiagnosticAnalyzer
+{
+ private static readonly DiagnosticDescriptor EqualsRule = new(
+ RuleIdentifiers.MissingNotNullWhenAttributeOnEquals,
+ title: "Equals method should use [NotNullWhen(true)] on the parameter",
+ messageFormat: "Equals method should use [NotNullWhen(true)] on parameter '{0}'",
+ RuleCategories.Design,
+ DiagnosticSeverity.Info,
+ isEnabledByDefault: false,
+ description: "",
+ helpLinkUri: RuleIdentifiers.GetHelpUri(RuleIdentifiers.MissingNotNullWhenAttributeOnEquals));
+
+ private static readonly DiagnosticDescriptor TryGetValueRule = new(
+ RuleIdentifiers.MissingNotNullWhenAttributeOnEquals,
+ title: "TryGetValue method should use [MaybeNullWhen(false)] on the value parameter",
+ messageFormat: "TryGetValue method should use [MaybeNullWhen(false)] on parameter '{0}'",
+ RuleCategories.Design,
+ DiagnosticSeverity.Info,
+ isEnabledByDefault: false,
+ description: "",
+ helpLinkUri: RuleIdentifiers.GetHelpUri(RuleIdentifiers.MissingNotNullWhenAttributeOnEquals));
+
+ public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(EqualsRule, TryGetValueRule);
+
+ public override void Initialize(AnalysisContext context)
+ {
+ context.EnableConcurrentExecution();
+ context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
+
+ context.RegisterCompilationStartAction(ctx =>
+ {
+ var notNullWhenAttributeSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.Diagnostics.CodeAnalysis.NotNullWhenAttribute");
+ var maybeNullWhenAttributeSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute");
+ var iequatableOfTSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.IEquatable`1");
+ var idictionaryOfTSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.Collections.Generic.IDictionary`2");
+
+ if (idictionaryOfTSymbol != null && maybeNullWhenAttributeSymbol is not null)
+ {
+ var tryGetValueSymbols = idictionaryOfTSymbol.GetMembers("TryGetValue");
+ if (tryGetValueSymbols.Length == 1)
+ {
+ var tryGetValueSymbol = tryGetValueSymbols[0];
+ ctx.RegisterSymbolAction(context =>
+ {
+ var namedType = (INamedTypeSymbol)context.Symbol;
+ foreach (var interfaceType in namedType.AllInterfaces)
+ {
+ if (!interfaceType.ConstructedFrom.IsEqualTo(idictionaryOfTSymbol))
+ continue;
+
+ var dictionaryTryGetValueSymbols = interfaceType.GetMembers("TryGetValue");
+ if (dictionaryTryGetValueSymbols.Length != 1)
+ continue;
+
+ var implementation = namedType.FindImplementationForInterfaceMember(dictionaryTryGetValueSymbols[0]) as IMethodSymbol;
+ if (implementation is null)
+ continue;
+
+ if (implementation.Parameters.Length != 2)
+ continue;
+
+ var valueParameter = implementation.Parameters[1];
+
+ // Check if the parameter is an out parameter
+ if (valueParameter.RefKind != RefKind.Out)
+ continue;
+
+ // Check if the parameter is nullable
+ if (valueParameter.NullableAnnotation != NullableAnnotation.Annotated)
+ continue;
+
+ // Check if the parameter already has [MaybeNullWhen(false)] attribute
+ if (HasMaybeNullWhenAttribute(valueParameter, maybeNullWhenAttributeSymbol, expectedValue: false))
+ continue;
+
+ // Report diagnostic
+ context.ReportDiagnostic(TryGetValueRule, valueParameter, valueParameter.Name);
+ }
+ }, SymbolKind.NamedType);
+ }
+
+
+ }
+
+ if (notNullWhenAttributeSymbol is not null)
+ {
+ context.RegisterSymbolAction(context =>
+ {
+ var method = (IMethodSymbol)context.Symbol;
+ if (method.Name is nameof(object.Equals))
+ {
+ if (!method.ReturnType.IsBoolean())
+ return;
+
+ if (method.Parameters.Length != 1)
+ return;
+
+ if (method.IsStatic)
+ return;
+
+ var parameter = method.Parameters[0];
+
+ // Check if the parameter is nullable, this also ensures nullable annotations are enabled for the parameter
+ if (parameter.NullableAnnotation != NullableAnnotation.Annotated)
+ return;
+
+
+ // Check if it's Equals(object?) override using helper
+ var isObjectEqualsOverride = false;
+ if (method.IsOverride && parameter.Type.IsObject())
+ {
+ // Verify it's overriding object.Equals by checking the base member
+ var currentMethod = method.OverriddenMethod;
+ while (currentMethod is not null)
+ {
+ if (currentMethod.ContainingType.IsObject())
+ {
+ isObjectEqualsOverride = true;
+ break;
+ }
+ currentMethod = currentMethod.OverriddenMethod;
+ }
+ }
+
+ // Check if it's IEquatable.Equals(T?) implementation using helper
+ var isIEquatableEquals = false;
+ if (iequatableOfTSymbol is not null && method.ContainingType is not null && !method.ContainingType.IsValueType)
+ {
+ if (method.IsInterfaceImplementation())
+ {
+ var interfaceMethod = method.GetImplementingInterfaceSymbol();
+ if (interfaceMethod is not null &&
+ interfaceMethod.ContainingType is INamedTypeSymbol interfaceType &&
+ interfaceType.ConstructedFrom.IsEqualTo(iequatableOfTSymbol))
+ {
+ isIEquatableEquals = true;
+ }
+ }
+ }
+
+ if (!isObjectEqualsOverride && !isIEquatableEquals)
+ return;
+
+ // Check if the parameter already has [NotNullWhen(true)] attribute
+ if (HasNotNullWhenAttribute(parameter, notNullWhenAttributeSymbol, expectedValue: true))
+ return;
+
+ // Report diagnostic
+ context.ReportDiagnostic(EqualsRule, parameter, parameter.Name);
+ }
+ }, SymbolKind.Method);
+ }
+ });
+ }
+
+ private static bool HasNotNullWhenAttribute(IParameterSymbol parameter, INamedTypeSymbol notNullWhenAttributeSymbol, bool expectedValue)
+ {
+ foreach (var attribute in parameter.GetAttributes())
+ {
+ if (attribute.AttributeClass.IsEqualTo(notNullWhenAttributeSymbol))
+ {
+ if (attribute.ConstructorArguments.Length == 1 && attribute.ConstructorArguments[0].Value is bool value && value == expectedValue)
+ {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static bool HasMaybeNullWhenAttribute(IParameterSymbol parameter, INamedTypeSymbol maybeNullWhenAttributeSymbol, bool expectedValue)
+ {
+ foreach (var attribute in parameter.GetAttributes())
+ {
+ if (attribute.AttributeClass.IsEqualTo(maybeNullWhenAttributeSymbol))
+ {
+ if (attribute.ConstructorArguments.Length == 1 && attribute.ConstructorArguments[0].Value is bool value && value == expectedValue)
+ {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+}
diff --git a/tests/Meziantou.Analyzer.Test/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzerTests.cs b/tests/Meziantou.Analyzer.Test/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzerTests.cs
new file mode 100644
index 00000000..25b2b3f5
--- /dev/null
+++ b/tests/Meziantou.Analyzer.Test/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzerTests.cs
@@ -0,0 +1,666 @@
+using Meziantou.Analyzer.Rules;
+using TestHelper;
+
+namespace Meziantou.Analyzer.Test.Rules;
+
+public sealed class MissingNotNullWhenAttributeOnEqualsAnalyzerTests
+{
+ private static ProjectBuilder CreateProjectBuilder()
+ {
+ return new ProjectBuilder()
+ .WithAnalyzer();
+ }
+
+ [Fact]
+ public async Task Equals_Object_WithoutAttribute_ShouldReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class Sample
+ {
+ public override bool Equals(object? [|obj|])
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_ExplicitWithoutAttribute_ShouldReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections.Generic;
+
+ class MyDictionary : IDictionary
+ {
+ bool IDictionary.TryGetValue(string key, out string? [|value|])
+ {
+ value = null;
+ return false;
+ }
+
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public ICollection Keys => throw new System.NotImplementedException();
+ public ICollection Values => throw new System.NotImplementedException();
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public IEnumerator> GetEnumerator() => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_ExplicitWithAttribute_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections.Generic;
+ using System.Diagnostics.CodeAnalysis;
+
+ class MyDictionary : IDictionary
+ {
+ bool IDictionary.TryGetValue(string key, [MaybeNullWhen(false)] out string? value)
+ {
+ value = null;
+ return false;
+ }
+
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public ICollection Keys => throw new System.NotImplementedException();
+ public ICollection Values => throw new System.NotImplementedException();
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public IEnumerator> GetEnumerator() => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_Object_WithAttribute_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Diagnostics.CodeAnalysis;
+
+ class Sample
+ {
+ public override bool Equals([NotNullWhen(true)] object? obj)
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_IEquatable_WithoutAttribute_ShouldReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System;
+
+ class Sample : IEquatable
+ {
+ public bool Equals(Sample? [|other|])
+ {
+ return false;
+ }
+
+ public override bool Equals(object? [|obj|])
+ {
+ return Equals(obj as Sample);
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_IEquatable_WithAttribute_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System;
+ using System.Diagnostics.CodeAnalysis;
+
+ class Sample : IEquatable
+ {
+ public bool Equals([NotNullWhen(true)] Sample? other)
+ {
+ return false;
+ }
+
+ public override bool Equals([NotNullWhen(true)] object? obj)
+ {
+ return Equals(obj as Sample);
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_NonNullableParameter_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class Sample
+ {
+ public override bool Equals(object obj)
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task NotEqualsMethod_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class Sample
+ {
+ public bool IsEqual(object? obj)
+ {
+ return false;
+ }
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task PrivateEquals_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class Sample
+ {
+ private bool Equals(object? obj)
+ {
+ return false;
+ }
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task StaticEquals_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class Sample
+ {
+ public static bool Equals(object? obj1, object? obj2)
+ {
+ return false;
+ }
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_WrongSignature_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class Sample
+ {
+ public bool Equals(object? obj, int x)
+ {
+ return false;
+ }
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_IEquatable_BothMethodsWithoutAttribute_ShouldReportBothDiagnostics()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System;
+
+ class Sample : IEquatable
+ {
+ public bool Equals(Sample? [|other|])
+ {
+ return false;
+ }
+
+ public override bool Equals(object? [|obj|])
+ {
+ return Equals(obj as Sample);
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_IEquatable_ValueType_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System;
+
+ struct Sample : IEquatable
+ {
+ public bool Equals(Sample other)
+ {
+ return false;
+ }
+
+ public override bool Equals(object? [|obj|])
+ {
+ return obj is Sample other && Equals(other);
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_NullableDisabled_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ #nullable disable
+ class Sample
+ {
+ public override bool Equals(object obj)
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_NullableEnabled_ShouldReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ #nullable enable
+ class Sample
+ {
+ public override bool Equals(object? [|obj|])
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task Equals_NullableEnabledThenDisabled_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ #nullable enable
+ class Sample1
+ {
+ public override bool Equals(object? [|obj|])
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+
+ #nullable disable
+ class Sample2
+ {
+ public override bool Equals(object obj)
+ {
+ return false;
+ }
+
+ public override int GetHashCode() => 0;
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_WithoutAttribute_ShouldReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections.Generic;
+
+ class MyDictionary : IDictionary
+ {
+ public bool TryGetValue(string key, out string? [|value|])
+ {
+ value = null;
+ return false;
+ }
+
+ // Other IDictionary members...
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public ICollection Keys => throw new System.NotImplementedException();
+ public ICollection Values => throw new System.NotImplementedException();
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public IEnumerator> GetEnumerator() => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_WithAttribute_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections.Generic;
+ using System.Diagnostics.CodeAnalysis;
+
+ class MyDictionary : IDictionary
+ {
+ public bool TryGetValue(string key, [MaybeNullWhen(false)] out string? value)
+ {
+ value = null;
+ return false;
+ }
+
+ // Other IDictionary members...
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public ICollection Keys => throw new System.NotImplementedException();
+ public ICollection Values => throw new System.NotImplementedException();
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public IEnumerator> GetEnumerator() => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_Twice_BothInvalid_ShouldReportDiagnostics()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections;
+ using System.Collections.Generic;
+
+ class MyDictionary : IDictionary, IDictionary
+ {
+ public bool TryGetValue(string key, out string? [|stringValue|])
+ {
+ stringValue = null;
+ return false;
+ }
+
+ public bool TryGetValue(int key, out string? [|intValue|])
+ {
+ intValue = null;
+ return false;
+ }
+
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public string? this[int key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+
+ ICollection IDictionary.Keys => throw new System.NotImplementedException();
+ ICollection IDictionary.Keys => throw new System.NotImplementedException();
+ ICollection IDictionary.Values => throw new System.NotImplementedException();
+ ICollection IDictionary.Values => throw new System.NotImplementedException();
+
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(int key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public bool ContainsKey(int key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(int key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException();
+ IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException();
+ IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_Twice_OneInvalid_ShouldReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections;
+ using System.Collections.Generic;
+ using System.Diagnostics.CodeAnalysis;
+
+ class MyDictionary : IDictionary, IDictionary
+ {
+ public bool TryGetValue(string key, out string? [|stringValue|])
+ {
+ stringValue = null;
+ return false;
+ }
+
+ public bool TryGetValue(int key, [MaybeNullWhen(false)] out string? intValue)
+ {
+ intValue = null;
+ return false;
+ }
+
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public string? this[int key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+
+ ICollection IDictionary.Keys => throw new System.NotImplementedException();
+ ICollection IDictionary.Keys => throw new System.NotImplementedException();
+ ICollection IDictionary.Values => throw new System.NotImplementedException();
+ ICollection IDictionary.Values => throw new System.NotImplementedException();
+
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(int key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public bool ContainsKey(int key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(int key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException();
+ IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException();
+ IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_IDictionary_Twice_NoneInvalid_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections;
+ using System.Collections.Generic;
+ using System.Diagnostics.CodeAnalysis;
+
+ class MyDictionary : IDictionary, IDictionary
+ {
+ public bool TryGetValue(string key, [MaybeNullWhen(false)] out string? stringValue)
+ {
+ stringValue = null;
+ return false;
+ }
+
+ public bool TryGetValue(int key, [MaybeNullWhen(false)] out string? intValue)
+ {
+ intValue = null;
+ return false;
+ }
+
+ public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public string? this[int key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+
+ ICollection IDictionary.Keys => throw new System.NotImplementedException();
+ ICollection IDictionary.Keys => throw new System.NotImplementedException();
+ ICollection IDictionary.Values => throw new System.NotImplementedException();
+ ICollection IDictionary.Values => throw new System.NotImplementedException();
+
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string? value) => throw new System.NotImplementedException();
+ public void Add(int key, string? value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public bool ContainsKey(int key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(int key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException();
+ IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException();
+ IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_NotIDictionary_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ class MyClass
+ {
+ public bool TryGetValue(string key, out string? value)
+ {
+ value = null;
+ return false;
+ }
+ }
+ """)
+ .ValidateAsync();
+ }
+
+ [Fact]
+ public async Task TryGetValue_NonNullableValue_ShouldNotReportDiagnostic()
+ {
+ await CreateProjectBuilder()
+ .WithSourceCode("""
+ using System.Collections.Generic;
+
+ class MyDictionary : IDictionary
+ {
+ public bool TryGetValue(string key, out string value)
+ {
+ value = "";
+ return false;
+ }
+
+ // Other IDictionary members...
+ public string this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
+ public ICollection Keys => throw new System.NotImplementedException();
+ public ICollection Values => throw new System.NotImplementedException();
+ public int Count => throw new System.NotImplementedException();
+ public bool IsReadOnly => throw new System.NotImplementedException();
+ public void Add(string key, string value) => throw new System.NotImplementedException();
+ public void Add(KeyValuePair item) => throw new System.NotImplementedException();
+ public void Clear() => throw new System.NotImplementedException();
+ public bool Contains(KeyValuePair item) => throw new System.NotImplementedException();
+ public bool ContainsKey(string key) => throw new System.NotImplementedException();
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException();
+ public IEnumerator> GetEnumerator() => throw new System.NotImplementedException();
+ public bool Remove(string key) => throw new System.NotImplementedException();
+ public bool Remove(KeyValuePair item) => throw new System.NotImplementedException();
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException();
+ }
+ """)
+ .ValidateAsync();
+ }
+}