diff --git a/README.md b/README.md index dedaafbd..e98755ee 100755 --- a/README.md +++ b/README.md @@ -200,6 +200,7 @@ If you are already using other analyzers, you can check [which rules are duplica |[MA0183](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0183.md)|Usage|string.Format should use a format string with placeholders|⚠️|✔️|❌| |[MA0184](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0184.md)|Style|Do not use interpolated string without parameters|👻|✔️|✔️| |[MA0185](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0185.md)|Performance|Simplify string.Create when all parameters are culture invariant|ℹ️|✔️|✔️| +|[MA0186](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0186.md)|Design|Equals method should use \[NotNullWhen(true)\] on the parameter|ℹ️|❌|❌| diff --git a/docs/README.md b/docs/README.md index 5aa8d0df..d572e88a 100755 --- a/docs/README.md +++ b/docs/README.md @@ -184,6 +184,7 @@ |[MA0183](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0183.md)|Usage|string.Format should use a format string with placeholders|⚠️|✔️|❌| |[MA0184](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0184.md)|Style|Do not use interpolated string without parameters|👻|✔️|✔️| |[MA0185](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0185.md)|Performance|Simplify string.Create when all parameters are culture invariant|ℹ️|✔️|✔️| +|[MA0186](https://github.com/meziantou/Meziantou.Analyzer/blob/main/docs/Rules/MA0186.md)|Design|Equals method should use \[NotNullWhen(true)\] on the parameter|ℹ️|❌|❌| |Id|Suppressed rule|Justification| |--|---------------|-------------| @@ -751,6 +752,9 @@ dotnet_diagnostic.MA0184.severity = silent # MA0185: Simplify string.Create when all parameters are culture invariant dotnet_diagnostic.MA0185.severity = suggestion + +# MA0186: Equals method should use [NotNullWhen(true)] on the parameter +dotnet_diagnostic.MA0186.severity = none ``` # .editorconfig - all rules disabled @@ -1304,4 +1308,7 @@ dotnet_diagnostic.MA0184.severity = none # MA0185: Simplify string.Create when all parameters are culture invariant dotnet_diagnostic.MA0185.severity = none + +# MA0186: Equals method should use [NotNullWhen(true)] on the parameter +dotnet_diagnostic.MA0186.severity = none ``` diff --git a/docs/Rules/MA0186.md b/docs/Rules/MA0186.md new file mode 100644 index 00000000..6a1de676 --- /dev/null +++ b/docs/Rules/MA0186.md @@ -0,0 +1,118 @@ +# MA0186 - Equals method should use \[NotNullWhen(true)\] on the parameter + +Source: [MissingNotNullWhenAttributeOnEqualsAnalyzer.cs](https://github.com/meziantou/Meziantou.Analyzer/blob/main/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs) + + +## Description + +This rule reports missing nullable attributes on method parameters to improve nullable reference type analysis: + +1. When implementing `Equals(object?)` or `IEquatable.Equals(T?)`, the parameter should be annotated with `[NotNullWhen(true)]` to inform the compiler that when the method returns `true`, the parameter is guaranteed to be non-null. + +2. When implementing `IDictionary.TryGetValue`, the `value` parameter should be annotated with `[MaybeNullWhen(false)]` to inform the compiler that when the method returns `false`, the parameter may be null. + +## Motivation + +### Equals methods + +Without the `[NotNullWhen(true)]` attribute, the compiler cannot track that a successful equality comparison means the parameter is non-null. This can lead to unnecessary null checks or null-forgiving operators in code that follows an equality check. + +````c# +// ❌ Without the attribute +public override bool Equals(object? obj) +{ + return obj is MyType other && /* comparison */; +} + +// Usage +if (myInstance.Equals(someObj)) +{ + // Compiler doesn't know someObj is not null here + someObj.ToString(); // Warning CS8602 +} + +// ✅ With the attribute +public override bool Equals([NotNullWhen(true)] object? obj) +{ + return obj is MyType other && /* comparison */; +} + +// Usage +if (myInstance.Equals(someObj)) +{ + // Compiler knows someObj is not null here + someObj.ToString(); // No warning +} +```` + +### TryGetValue methods + +Without the `[MaybeNullWhen(false)]` attribute, the compiler cannot track that when `TryGetValue` returns `false`, the out parameter may be null. + +````c# +// ❌ Without the attribute +public bool TryGetValue(TKey key, out TValue? value) +{ + // Implementation +} + +// ✅ With the attribute +public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue? value) +{ + // Implementation +} + +// Usage +if (dictionary.TryGetValue(key, out var value)) +{ + value.ToString(); // Compiler knows value is not null +} +```` + +## How to fix violations + +### For Equals methods + +Add the `[NotNullWhen(true)]` attribute to the parameter: + +````c# +using System.Diagnostics.CodeAnalysis; + +public class MyType : IEquatable +{ + public override bool Equals([NotNullWhen(true)] object? obj) + { + return Equals(obj as MyType); + } + + public bool Equals([NotNullWhen(true)] MyType? other) + { + return other is not null && /* comparison logic */; + } +} +```` + +### For TryGetValue methods + +Add the `[MaybeNullWhen(false)]` attribute to the value parameter: + +````c# +using System.Diagnostics.CodeAnalysis; +using System.Collections.Generic; + +public class MyDictionary : IDictionary +{ + public bool TryGetValue(string key, [MaybeNullWhen(false)] out string? value) + { + // Implementation + } +} +```` + +## Configuration + +This rule is disabled by default as it's a suggestion for improved nullable annotations. To enable it, add the following to your `.editorconfig` file: + +````editorconfig +dotnet_diagnostic.MA0186.severity = suggestion +```` diff --git a/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig b/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig index 23e24986..4b5016fe 100644 --- a/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig +++ b/src/Meziantou.Analyzer.Pack/configuration/default.editorconfig @@ -550,3 +550,6 @@ dotnet_diagnostic.MA0184.severity = silent # MA0185: Simplify string.Create when all parameters are culture invariant dotnet_diagnostic.MA0185.severity = suggestion + +# MA0186: Equals method should use [NotNullWhen(true)] on the parameter +dotnet_diagnostic.MA0186.severity = none diff --git a/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig b/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig index bb651f75..09a71638 100644 --- a/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig +++ b/src/Meziantou.Analyzer.Pack/configuration/none.editorconfig @@ -550,3 +550,6 @@ dotnet_diagnostic.MA0184.severity = none # MA0185: Simplify string.Create when all parameters are culture invariant dotnet_diagnostic.MA0185.severity = none + +# MA0186: Equals method should use [NotNullWhen(true)] on the parameter +dotnet_diagnostic.MA0186.severity = none diff --git a/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs b/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs index e92414c4..2ecb89ef 100755 --- a/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs +++ b/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs @@ -48,6 +48,14 @@ public static bool Implements(this ITypeSymbol classSymbol, ITypeSymbol? interfa return classSymbol.AllInterfaces.Any(interfaceType.IsEqualTo); } + public static bool ImplementsGenericInterface(this ITypeSymbol classSymbol, ITypeSymbol? interfaceType) + { + if (interfaceType is null) + return false; + + return classSymbol.AllInterfaces.Any(iface => iface.OriginalDefinition.IsEqualTo(interfaceType.OriginalDefinition)); + } + public static bool IsOrImplements(this ITypeSymbol symbol, ITypeSymbol? interfaceType) { if (interfaceType is null) diff --git a/src/Meziantou.Analyzer/RuleIdentifiers.cs b/src/Meziantou.Analyzer/RuleIdentifiers.cs index 22d72ced..314aa7f8 100755 --- a/src/Meziantou.Analyzer/RuleIdentifiers.cs +++ b/src/Meziantou.Analyzer/RuleIdentifiers.cs @@ -185,6 +185,7 @@ internal static class RuleIdentifiers public const string StringFormatShouldBeConstant = "MA0183"; public const string DoNotUseInterpolatedStringWithoutParameters = "MA0184"; public const string SimplifyStringCreateWhenAllParametersAreCultureInvariant = "MA0185"; + public const string MissingNotNullWhenAttributeOnEquals = "MA0186"; public static string GetHelpUri(string identifier) { diff --git a/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs b/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs new file mode 100644 index 00000000..e1949b47 --- /dev/null +++ b/src/Meziantou.Analyzer/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzer.cs @@ -0,0 +1,194 @@ +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Meziantou.Analyzer.Internals; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Meziantou.Analyzer.Rules; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public sealed class MissingNotNullWhenAttributeOnEqualsAnalyzer : DiagnosticAnalyzer +{ + private static readonly DiagnosticDescriptor EqualsRule = new( + RuleIdentifiers.MissingNotNullWhenAttributeOnEquals, + title: "Equals method should use [NotNullWhen(true)] on the parameter", + messageFormat: "Equals method should use [NotNullWhen(true)] on parameter '{0}'", + RuleCategories.Design, + DiagnosticSeverity.Info, + isEnabledByDefault: false, + description: "", + helpLinkUri: RuleIdentifiers.GetHelpUri(RuleIdentifiers.MissingNotNullWhenAttributeOnEquals)); + + private static readonly DiagnosticDescriptor TryGetValueRule = new( + RuleIdentifiers.MissingNotNullWhenAttributeOnEquals, + title: "TryGetValue method should use [MaybeNullWhen(false)] on the value parameter", + messageFormat: "TryGetValue method should use [MaybeNullWhen(false)] on parameter '{0}'", + RuleCategories.Design, + DiagnosticSeverity.Info, + isEnabledByDefault: false, + description: "", + helpLinkUri: RuleIdentifiers.GetHelpUri(RuleIdentifiers.MissingNotNullWhenAttributeOnEquals)); + + public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(EqualsRule, TryGetValueRule); + + public override void Initialize(AnalysisContext context) + { + context.EnableConcurrentExecution(); + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + + context.RegisterCompilationStartAction(ctx => + { + var notNullWhenAttributeSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.Diagnostics.CodeAnalysis.NotNullWhenAttribute"); + var maybeNullWhenAttributeSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute"); + var iequatableOfTSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.IEquatable`1"); + var idictionaryOfTSymbol = ctx.Compilation.GetBestTypeByMetadataName("System.Collections.Generic.IDictionary`2"); + + if (idictionaryOfTSymbol != null && maybeNullWhenAttributeSymbol is not null) + { + var tryGetValueSymbols = idictionaryOfTSymbol.GetMembers("TryGetValue"); + if (tryGetValueSymbols.Length == 1) + { + var tryGetValueSymbol = tryGetValueSymbols[0]; + ctx.RegisterSymbolAction(context => + { + var namedType = (INamedTypeSymbol)context.Symbol; + foreach (var interfaceType in namedType.AllInterfaces) + { + if (!interfaceType.ConstructedFrom.IsEqualTo(idictionaryOfTSymbol)) + continue; + + var dictionaryTryGetValueSymbols = interfaceType.GetMembers("TryGetValue"); + if (dictionaryTryGetValueSymbols.Length != 1) + continue; + + var implementation = namedType.FindImplementationForInterfaceMember(dictionaryTryGetValueSymbols[0]) as IMethodSymbol; + if (implementation is null) + continue; + + if (implementation.Parameters.Length != 2) + continue; + + var valueParameter = implementation.Parameters[1]; + + // Check if the parameter is an out parameter + if (valueParameter.RefKind != RefKind.Out) + continue; + + // Check if the parameter is nullable + if (valueParameter.NullableAnnotation != NullableAnnotation.Annotated) + continue; + + // Check if the parameter already has [MaybeNullWhen(false)] attribute + if (HasMaybeNullWhenAttribute(valueParameter, maybeNullWhenAttributeSymbol, expectedValue: false)) + continue; + + // Report diagnostic + context.ReportDiagnostic(TryGetValueRule, valueParameter, valueParameter.Name); + } + }, SymbolKind.NamedType); + } + + + } + + if (notNullWhenAttributeSymbol is not null) + { + context.RegisterSymbolAction(context => + { + var method = (IMethodSymbol)context.Symbol; + if (method.Name is nameof(object.Equals)) + { + if (!method.ReturnType.IsBoolean()) + return; + + if (method.Parameters.Length != 1) + return; + + if (method.IsStatic) + return; + + var parameter = method.Parameters[0]; + + // Check if the parameter is nullable, this also ensures nullable annotations are enabled for the parameter + if (parameter.NullableAnnotation != NullableAnnotation.Annotated) + return; + + + // Check if it's Equals(object?) override using helper + var isObjectEqualsOverride = false; + if (method.IsOverride && parameter.Type.IsObject()) + { + // Verify it's overriding object.Equals by checking the base member + var currentMethod = method.OverriddenMethod; + while (currentMethod is not null) + { + if (currentMethod.ContainingType.IsObject()) + { + isObjectEqualsOverride = true; + break; + } + currentMethod = currentMethod.OverriddenMethod; + } + } + + // Check if it's IEquatable.Equals(T?) implementation using helper + var isIEquatableEquals = false; + if (iequatableOfTSymbol is not null && method.ContainingType is not null && !method.ContainingType.IsValueType) + { + if (method.IsInterfaceImplementation()) + { + var interfaceMethod = method.GetImplementingInterfaceSymbol(); + if (interfaceMethod is not null && + interfaceMethod.ContainingType is INamedTypeSymbol interfaceType && + interfaceType.ConstructedFrom.IsEqualTo(iequatableOfTSymbol)) + { + isIEquatableEquals = true; + } + } + } + + if (!isObjectEqualsOverride && !isIEquatableEquals) + return; + + // Check if the parameter already has [NotNullWhen(true)] attribute + if (HasNotNullWhenAttribute(parameter, notNullWhenAttributeSymbol, expectedValue: true)) + return; + + // Report diagnostic + context.ReportDiagnostic(EqualsRule, parameter, parameter.Name); + } + }, SymbolKind.Method); + } + }); + } + + private static bool HasNotNullWhenAttribute(IParameterSymbol parameter, INamedTypeSymbol notNullWhenAttributeSymbol, bool expectedValue) + { + foreach (var attribute in parameter.GetAttributes()) + { + if (attribute.AttributeClass.IsEqualTo(notNullWhenAttributeSymbol)) + { + if (attribute.ConstructorArguments.Length == 1 && attribute.ConstructorArguments[0].Value is bool value && value == expectedValue) + { + return true; + } + } + } + return false; + } + + private static bool HasMaybeNullWhenAttribute(IParameterSymbol parameter, INamedTypeSymbol maybeNullWhenAttributeSymbol, bool expectedValue) + { + foreach (var attribute in parameter.GetAttributes()) + { + if (attribute.AttributeClass.IsEqualTo(maybeNullWhenAttributeSymbol)) + { + if (attribute.ConstructorArguments.Length == 1 && attribute.ConstructorArguments[0].Value is bool value && value == expectedValue) + { + return true; + } + } + } + return false; + } +} diff --git a/tests/Meziantou.Analyzer.Test/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzerTests.cs b/tests/Meziantou.Analyzer.Test/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzerTests.cs new file mode 100644 index 00000000..25b2b3f5 --- /dev/null +++ b/tests/Meziantou.Analyzer.Test/Rules/MissingNotNullWhenAttributeOnEqualsAnalyzerTests.cs @@ -0,0 +1,666 @@ +using Meziantou.Analyzer.Rules; +using TestHelper; + +namespace Meziantou.Analyzer.Test.Rules; + +public sealed class MissingNotNullWhenAttributeOnEqualsAnalyzerTests +{ + private static ProjectBuilder CreateProjectBuilder() + { + return new ProjectBuilder() + .WithAnalyzer(); + } + + [Fact] + public async Task Equals_Object_WithoutAttribute_ShouldReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class Sample + { + public override bool Equals(object? [|obj|]) + { + return false; + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_ExplicitWithoutAttribute_ShouldReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections.Generic; + + class MyDictionary : IDictionary + { + bool IDictionary.TryGetValue(string key, out string? [|value|]) + { + value = null; + return false; + } + + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public ICollection Keys => throw new System.NotImplementedException(); + public ICollection Values => throw new System.NotImplementedException(); + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public IEnumerator> GetEnumerator() => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_ExplicitWithAttribute_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections.Generic; + using System.Diagnostics.CodeAnalysis; + + class MyDictionary : IDictionary + { + bool IDictionary.TryGetValue(string key, [MaybeNullWhen(false)] out string? value) + { + value = null; + return false; + } + + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public ICollection Keys => throw new System.NotImplementedException(); + public ICollection Values => throw new System.NotImplementedException(); + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public IEnumerator> GetEnumerator() => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_Object_WithAttribute_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Diagnostics.CodeAnalysis; + + class Sample + { + public override bool Equals([NotNullWhen(true)] object? obj) + { + return false; + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_IEquatable_WithoutAttribute_ShouldReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System; + + class Sample : IEquatable + { + public bool Equals(Sample? [|other|]) + { + return false; + } + + public override bool Equals(object? [|obj|]) + { + return Equals(obj as Sample); + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_IEquatable_WithAttribute_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System; + using System.Diagnostics.CodeAnalysis; + + class Sample : IEquatable + { + public bool Equals([NotNullWhen(true)] Sample? other) + { + return false; + } + + public override bool Equals([NotNullWhen(true)] object? obj) + { + return Equals(obj as Sample); + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_NonNullableParameter_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class Sample + { + public override bool Equals(object obj) + { + return false; + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task NotEqualsMethod_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class Sample + { + public bool IsEqual(object? obj) + { + return false; + } + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task PrivateEquals_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class Sample + { + private bool Equals(object? obj) + { + return false; + } + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task StaticEquals_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class Sample + { + public static bool Equals(object? obj1, object? obj2) + { + return false; + } + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_WrongSignature_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class Sample + { + public bool Equals(object? obj, int x) + { + return false; + } + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_IEquatable_BothMethodsWithoutAttribute_ShouldReportBothDiagnostics() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System; + + class Sample : IEquatable + { + public bool Equals(Sample? [|other|]) + { + return false; + } + + public override bool Equals(object? [|obj|]) + { + return Equals(obj as Sample); + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_IEquatable_ValueType_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System; + + struct Sample : IEquatable + { + public bool Equals(Sample other) + { + return false; + } + + public override bool Equals(object? [|obj|]) + { + return obj is Sample other && Equals(other); + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_NullableDisabled_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + #nullable disable + class Sample + { + public override bool Equals(object obj) + { + return false; + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_NullableEnabled_ShouldReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + #nullable enable + class Sample + { + public override bool Equals(object? [|obj|]) + { + return false; + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task Equals_NullableEnabledThenDisabled_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + #nullable enable + class Sample1 + { + public override bool Equals(object? [|obj|]) + { + return false; + } + + public override int GetHashCode() => 0; + } + + #nullable disable + class Sample2 + { + public override bool Equals(object obj) + { + return false; + } + + public override int GetHashCode() => 0; + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_WithoutAttribute_ShouldReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections.Generic; + + class MyDictionary : IDictionary + { + public bool TryGetValue(string key, out string? [|value|]) + { + value = null; + return false; + } + + // Other IDictionary members... + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public ICollection Keys => throw new System.NotImplementedException(); + public ICollection Values => throw new System.NotImplementedException(); + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public IEnumerator> GetEnumerator() => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_WithAttribute_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections.Generic; + using System.Diagnostics.CodeAnalysis; + + class MyDictionary : IDictionary + { + public bool TryGetValue(string key, [MaybeNullWhen(false)] out string? value) + { + value = null; + return false; + } + + // Other IDictionary members... + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public ICollection Keys => throw new System.NotImplementedException(); + public ICollection Values => throw new System.NotImplementedException(); + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public IEnumerator> GetEnumerator() => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_Twice_BothInvalid_ShouldReportDiagnostics() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections; + using System.Collections.Generic; + + class MyDictionary : IDictionary, IDictionary + { + public bool TryGetValue(string key, out string? [|stringValue|]) + { + stringValue = null; + return false; + } + + public bool TryGetValue(int key, out string? [|intValue|]) + { + intValue = null; + return false; + } + + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public string? this[int key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + + ICollection IDictionary.Keys => throw new System.NotImplementedException(); + ICollection IDictionary.Keys => throw new System.NotImplementedException(); + ICollection IDictionary.Values => throw new System.NotImplementedException(); + ICollection IDictionary.Values => throw new System.NotImplementedException(); + + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(int key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public bool ContainsKey(int key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(int key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException(); + IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException(); + IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_Twice_OneInvalid_ShouldReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.CodeAnalysis; + + class MyDictionary : IDictionary, IDictionary + { + public bool TryGetValue(string key, out string? [|stringValue|]) + { + stringValue = null; + return false; + } + + public bool TryGetValue(int key, [MaybeNullWhen(false)] out string? intValue) + { + intValue = null; + return false; + } + + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public string? this[int key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + + ICollection IDictionary.Keys => throw new System.NotImplementedException(); + ICollection IDictionary.Keys => throw new System.NotImplementedException(); + ICollection IDictionary.Values => throw new System.NotImplementedException(); + ICollection IDictionary.Values => throw new System.NotImplementedException(); + + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(int key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public bool ContainsKey(int key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(int key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException(); + IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException(); + IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_IDictionary_Twice_NoneInvalid_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics.CodeAnalysis; + + class MyDictionary : IDictionary, IDictionary + { + public bool TryGetValue(string key, [MaybeNullWhen(false)] out string? stringValue) + { + stringValue = null; + return false; + } + + public bool TryGetValue(int key, [MaybeNullWhen(false)] out string? intValue) + { + intValue = null; + return false; + } + + public string? this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public string? this[int key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + + ICollection IDictionary.Keys => throw new System.NotImplementedException(); + ICollection IDictionary.Keys => throw new System.NotImplementedException(); + ICollection IDictionary.Values => throw new System.NotImplementedException(); + ICollection IDictionary.Values => throw new System.NotImplementedException(); + + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string? value) => throw new System.NotImplementedException(); + public void Add(int key, string? value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public bool ContainsKey(int key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(int key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException(); + IEnumerator> IEnumerable>.GetEnumerator() => throw new System.NotImplementedException(); + IEnumerator IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_NotIDictionary_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + class MyClass + { + public bool TryGetValue(string key, out string? value) + { + value = null; + return false; + } + } + """) + .ValidateAsync(); + } + + [Fact] + public async Task TryGetValue_NonNullableValue_ShouldNotReportDiagnostic() + { + await CreateProjectBuilder() + .WithSourceCode(""" + using System.Collections.Generic; + + class MyDictionary : IDictionary + { + public bool TryGetValue(string key, out string value) + { + value = ""; + return false; + } + + // Other IDictionary members... + public string this[string key] { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + public ICollection Keys => throw new System.NotImplementedException(); + public ICollection Values => throw new System.NotImplementedException(); + public int Count => throw new System.NotImplementedException(); + public bool IsReadOnly => throw new System.NotImplementedException(); + public void Add(string key, string value) => throw new System.NotImplementedException(); + public void Add(KeyValuePair item) => throw new System.NotImplementedException(); + public void Clear() => throw new System.NotImplementedException(); + public bool Contains(KeyValuePair item) => throw new System.NotImplementedException(); + public bool ContainsKey(string key) => throw new System.NotImplementedException(); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => throw new System.NotImplementedException(); + public IEnumerator> GetEnumerator() => throw new System.NotImplementedException(); + public bool Remove(string key) => throw new System.NotImplementedException(); + public bool Remove(KeyValuePair item) => throw new System.NotImplementedException(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => throw new System.NotImplementedException(); + } + """) + .ValidateAsync(); + } +}