diff --git a/ChangeLog.md b/ChangeLog.md index 1eb4cd1d0b..2e40e1e514 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Do not simplify 'default' expression if the type is inferred ([RCS1244](https://github.com/JosefPihrt/Roslynator/blob/main/docs/analyzers/RCS1244.md)) ([#966](https://github.com/josefpihrt/roslynator/pull/966). - Use explicit type from lambda expression ([RCS1008](https://github.com/JosefPihrt/Roslynator/blob/main/docs/analyzers/RCS1008.md)) ([#967](https://github.com/josefpihrt/roslynator/pull/967). - Do not remove constructor if it is decorated with 'UsedImplicitlyAttribute' ([RCS1074](https://github.com/JosefPihrt/Roslynator/blob/main/docs/analyzers/RCS1074.md)) ([#968](https://github.com/josefpihrt/roslynator/pull/968). +- Detect argument null check in the form of `ArgumentNullException.ThrowIfNull` ([RR0025](https://github.com/JosefPihrt/Roslynator/blob/main/docs/refactorings/RR0025.md), [RCS1227](https://github.com/JosefPihrt/Roslynator/blob/main/docs/analyzers/RCS1227.md)) ([#974](https://github.com/josefpihrt/roslynator/pull/974). ----- diff --git a/src/Analyzers/CSharp/Analysis/ValidateArgumentsCorrectlyAnalyzer.cs b/src/Analyzers/CSharp/Analysis/ValidateArgumentsCorrectlyAnalyzer.cs index 70262ff635..f5b4ee9006 100644 --- a/src/Analyzers/CSharp/Analysis/ValidateArgumentsCorrectlyAnalyzer.cs +++ b/src/Analyzers/CSharp/Analysis/ValidateArgumentsCorrectlyAnalyzer.cs @@ -1,6 +1,5 @@ // Copyright (c) Josef Pihrt and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -58,7 +57,7 @@ private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context) int index = -1; for (int i = 0; i < statementCount; i++) { - if (IsNullCheck(statements[i])) + if (ArgumentNullCheckAnalysis.IsArgumentNullCheck(statements[i], context.SemanticModel, context.CancellationToken)) { index++; } @@ -100,11 +99,5 @@ private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context) DiagnosticRules.ValidateArgumentsCorrectly, Location.Create(body.SyntaxTree, new TextSpan(statements[index + 1].SpanStart, 0))); } - - private static bool IsNullCheck(StatementSyntax statement) - { - return statement.IsKind(SyntaxKind.IfStatement) - && ((IfStatementSyntax)statement).SingleNonBlockStatementOrDefault().IsKind(SyntaxKind.ThrowStatement); - } } } diff --git a/src/Common/ArgumentNullCheckAnalysis.cs b/src/Common/ArgumentNullCheckAnalysis.cs new file mode 100644 index 0000000000..845772f4b2 --- /dev/null +++ b/src/Common/ArgumentNullCheckAnalysis.cs @@ -0,0 +1,123 @@ +// Copyright (c) Josef Pihrt and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Roslynator.CSharp.Syntax; + +namespace Roslynator.CSharp +{ + internal readonly struct ArgumentNullCheckAnalysis + { + private ArgumentNullCheckAnalysis(ArgumentNullCheckStyle style, string name, bool success) + { + Style = style; + Name = name; + Success = success; + } + + public ArgumentNullCheckStyle Style { get; } + + public string Name { get; } + + public bool Success { get; } + + public static ArgumentNullCheckAnalysis Create( + StatementSyntax statement, + SemanticModel semanticModel, + CancellationToken cancellationToken = default) + { + return Create(statement, semanticModel, name: null, cancellationToken); + } + + public static ArgumentNullCheckAnalysis Create( + StatementSyntax statement, + SemanticModel semanticModel, + string name, + CancellationToken cancellationToken = default) + { + var style = ArgumentNullCheckStyle.None; + string identifier = null; + var success = false; + + if (statement is IfStatementSyntax ifStatement) + { + if (ifStatement.SingleNonBlockStatementOrDefault() is ThrowStatementSyntax throwStatement + && throwStatement.Expression is ObjectCreationExpressionSyntax objectCreation) + { + NullCheckExpressionInfo nullCheck = SyntaxInfo.NullCheckExpressionInfo( + ifStatement.Condition, + semanticModel, + NullCheckStyles.EqualsToNull | NullCheckStyles.IsNull, + cancellationToken: cancellationToken); + + if (nullCheck.Success) + { + style = ArgumentNullCheckStyle.IfStatement; + + if (nullCheck.Expression is IdentifierNameSyntax identifierName) + { + identifier = identifierName.Identifier.ValueText; + + if (name is null + || string.Equals(name, identifier, StringComparison.Ordinal)) + { + if (semanticModel + .GetSymbol(objectCreation, cancellationToken)? + .ContainingType? + .HasMetadataName(MetadataNames.System_ArgumentNullException) == true) + { + success = true; + } + } + } + } + + return new ArgumentNullCheckAnalysis(style, identifier, success); + } + } + else if (statement is ExpressionStatementSyntax expressionStatement) + { + SimpleMemberInvocationStatementInfo invocationInfo = SyntaxInfo.SimpleMemberInvocationStatementInfo(expressionStatement); + + if (invocationInfo.Success + && string.Equals(invocationInfo.NameText, "ThrowIfNull", StringComparison.Ordinal) + && semanticModel + .GetSymbol(invocationInfo.InvocationExpression, cancellationToken)? + .ContainingType? + .HasMetadataName(MetadataNames.System_ArgumentNullException) == true) + { + style = ArgumentNullCheckStyle.ThrowIfNullMethod; + + if (invocationInfo.Arguments.SingleOrDefault(shouldThrow: false)?.Expression is IdentifierNameSyntax identifierName) + { + identifier = identifierName.Identifier.ValueText; + + if (string.Equals(name, identifier, StringComparison.Ordinal)) + success = true; + } + } + } + + return new ArgumentNullCheckAnalysis(style, identifier, success); + } + + public static bool IsArgumentNullCheck( + StatementSyntax statement, + SemanticModel semanticModel, + CancellationToken cancellationToken = default) + { + return IsArgumentNullCheck(statement, semanticModel, name: null, cancellationToken); + } + + public static bool IsArgumentNullCheck( + StatementSyntax statement, + SemanticModel semanticModel, + string name, + CancellationToken cancellationToken = default) + { + return Create(statement, semanticModel, name, cancellationToken).Success; + } + } +} diff --git a/src/Common/ArgumentNullCheckStyle.cs b/src/Common/ArgumentNullCheckStyle.cs new file mode 100644 index 0000000000..9c973c3d08 --- /dev/null +++ b/src/Common/ArgumentNullCheckStyle.cs @@ -0,0 +1,11 @@ +// Copyright (c) Josef Pihrt and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Roslynator.CSharp +{ + internal enum ArgumentNullCheckStyle + { + None, + IfStatement, + ThrowIfNullMethod, + } +} diff --git a/src/Refactorings/CSharp/Refactorings/CheckParameterForNullRefactoring.cs b/src/Refactorings/CSharp/Refactorings/CheckParameterForNullRefactoring.cs index 1fc494c837..b7b544f8d7 100644 --- a/src/Refactorings/CSharp/Refactorings/CheckParameterForNullRefactoring.cs +++ b/src/Refactorings/CSharp/Refactorings/CheckParameterForNullRefactoring.cs @@ -105,11 +105,11 @@ public static bool CanRefactor( foreach (StatementSyntax statement in body.Statements) { - NullCheckExpressionInfo nullCheck = GetNullCheckExpressionInfo(statement, semanticModel, cancellationToken); + ArgumentNullCheckAnalysis nullCheck = ArgumentNullCheckAnalysis.Create(statement, semanticModel, parameter.Identifier.ValueText, cancellationToken); - if (nullCheck.Success) + if (nullCheck.Style != ArgumentNullCheckStyle.None) { - if (string.Equals(((IdentifierNameSyntax)nullCheck.Expression).Identifier.ValueText, parameter.Identifier.ValueText, StringComparison.Ordinal)) + if (nullCheck.Success) return false; } else @@ -132,7 +132,7 @@ public static Task RefactorAsync( SyntaxList statements = body.Statements; int count = statements - .TakeWhile(f => GetNullCheckExpressionInfo(f, semanticModel, cancellationToken).Success) + .TakeWhile(f => ArgumentNullCheckAnalysis.IsArgumentNullCheck(f, semanticModel, cancellationToken)) .Count(); List ifStatements = CreateNullChecks(parameters); @@ -200,40 +200,6 @@ private static List CreateNullChecks(ImmutableArray