Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 144 additions & 1 deletion TUnit.Assertions.Analyzers.Tests/IsNotNullAssertionSuppressorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace TUnit.Assertions.Analyzers.Tests;

/// <summary>
/// Tests for the IsNotNullAssertionSuppressor which suppresses nullability warnings
/// (CS8600, CS8602, CS8604, CS8618) for variables after Assert.That(x).IsNotNull().
/// (CS8600, CS8602, CS8604, CS8618, CS8629) for variables after Assert.That(x).IsNotNull().
///
/// Note: These tests verify that the suppressor correctly identifies and suppresses
/// nullability warnings. The suppressor does not change null-state flow analysis,
Expand All @@ -16,6 +16,7 @@ public class IsNotNullAssertionSuppressorTests
{
private static readonly DiagnosticResult CS8602 = new("CS8602", DiagnosticSeverity.Warning);
private static readonly DiagnosticResult CS8604 = new("CS8604", DiagnosticSeverity.Warning);
private static readonly DiagnosticResult CS8629 = new("CS8629", DiagnosticSeverity.Warning);

[Test]
public async Task Suppresses_CS8602_After_IsNotNull_Assertion()
Expand Down Expand Up @@ -422,4 +423,146 @@ await AnalyzerTestHelpers
.WithCompilerDiagnostics(CompilerDiagnostics.Warnings)
.RunAsync();
}

[Test]
public async Task Suppresses_CS8629_After_IsNotNull_Assertion_On_Simple_Nullable_Value_Type()
{
const string code = """
#nullable enable
using System.Threading.Tasks;
using TUnit.Assertions;
using TUnit.Assertions.Extensions;

public class MyTests
{
public async Task TestMethod()
{
int? nullableInt = GetNullableInt();

await Assert.That(nullableInt).IsNotNull();

// This would normally produce CS8629: Nullable value type may be null
// But the suppressor should suppress it after IsNotNull assertion
int value = {|#0:nullableInt|}.Value;
}

private int? GetNullableInt() => 42;
}
""";

await AnalyzerTestHelpers
.CreateSuppressorTest<IsNotNullAssertionSuppressor>(code)
.IgnoringDiagnostics("CS1591")
.WithSpecificDiagnostics(CS8629)
.WithExpectedDiagnosticsResults(CS8629.WithLocation(0).WithIsSuppressed(true))
.WithCompilerDiagnostics(CompilerDiagnostics.Warnings)
.RunAsync();
}

[Test]
public async Task Does_Not_Suppress_CS8629_Without_IsNotNull_Assertion()
{
const string code = """
#nullable enable
using System.Threading.Tasks;
using TUnit.Assertions;
using TUnit.Assertions.Extensions;

public class MyTests
{
public void TestMethod()
{
int? nullableInt = GetNullableInt();

// No IsNotNull assertion here

// This should still produce CS8629 warning
int value = {|#0:nullableInt|}.Value;
}

private int? GetNullableInt() => 42;
}
""";

await AnalyzerTestHelpers
.CreateSuppressorTest<IsNotNullAssertionSuppressor>(code)
.IgnoringDiagnostics("CS1591")
.WithSpecificDiagnostics(CS8629)
.WithExpectedDiagnosticsResults(CS8629.WithLocation(0).WithIsSuppressed(false))
.WithCompilerDiagnostics(CompilerDiagnostics.Warnings)
.RunAsync();
}

[Test]
public async Task Suppresses_CS8629_On_Member_Access_Nullable_Value_Type()
{
const string code = """
#nullable enable
using System.Threading.Tasks;
using TUnit.Assertions;
using TUnit.Assertions.Extensions;

public class MyTests
{
public async Task TestMethod(int? id)
{
var value = new { Id = id };

await Assert.That(value.Id).IsNotNull();

// This would normally produce CS8629: Nullable value type may be null
// But the suppressor should suppress it after IsNotNull assertion on value.Id
int idValue = {|#0:value.Id|}.Value;
}
}
""";

await AnalyzerTestHelpers
.CreateSuppressorTest<IsNotNullAssertionSuppressor>(code)
.IgnoringDiagnostics("CS1591")
.WithSpecificDiagnostics(CS8629)
.WithExpectedDiagnosticsResults(CS8629.WithLocation(0).WithIsSuppressed(true))
.WithCompilerDiagnostics(CompilerDiagnostics.Warnings)
.RunAsync();
}

[Test]
public async Task Suppresses_CS8629_On_Named_Type_Member_Access()
{
const string code = """
#nullable enable
using System.Threading.Tasks;
using TUnit.Assertions;
using TUnit.Assertions.Extensions;

public class MyModel
{
public int? Id { get; set; }
}

public class MyTests
{
public async Task TestMethod()
{
var model = GetModel();

await Assert.That(model.Id).IsNotNull();

// This would normally produce CS8629: Nullable value type may be null
// But the suppressor should suppress it after IsNotNull assertion on model.Id
int idValue = {|#0:model.Id|}.Value;
}

private MyModel GetModel() => new MyModel { Id = 42 };
}
""";

await AnalyzerTestHelpers
.CreateSuppressorTest<IsNotNullAssertionSuppressor>(code)
.IgnoringDiagnostics("CS1591")
.WithSpecificDiagnostics(CS8629)
.WithExpectedDiagnosticsResults(CS8629.WithLocation(0).WithIsSuppressed(true))
.WithCompilerDiagnostics(CompilerDiagnostics.Warnings)
.RunAsync();
}
}
87 changes: 58 additions & 29 deletions TUnit.Assertions.Analyzers/IsNotNullAssertionSuppressor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace TUnit.Assertions.Analyzers;

/// <summary>
/// Suppresses nullability warnings (CS8600, CS8602, CS8604, CS8618) for variables
/// Suppresses nullability warnings (CS8600, CS8602, CS8604, CS8618, CS8629) for variables
/// after they have been asserted as non-null using Assert.That(x).IsNotNull().
///
/// Note: This suppressor only hides the warnings; it does not change the compiler's
Expand Down Expand Up @@ -43,15 +43,15 @@ public override void ReportSuppressions(SuppressionAnalysisContext context)

var semanticModel = context.GetSemanticModel(sourceTree);

// Find the variable being referenced that caused the warning
var identifierName = GetIdentifierFromNode(node);
if (identifierName is null)
// Find the variable/expression being referenced that caused the warning
var targetExpression = GetTargetExpression(node);
if (targetExpression is null)
{
continue;
}

// Check if this variable was previously asserted as non-null
if (WasAssertedNotNull(identifierName, semanticModel, context.CancellationToken))
// Check if this variable/expression was previously asserted as non-null
if (WasAssertedNotNull(targetExpression, semanticModel, context.CancellationToken))
{
Suppress(context, diagnostic);
}
Expand All @@ -63,42 +63,39 @@ private bool IsNullabilityWarning(string diagnosticId)
return diagnosticId is "CS8600" // Converting null literal or possible null value to non-nullable type
or "CS8602" // Dereference of a possibly null reference
or "CS8604" // Possible null reference argument
or "CS8618"; // Non-nullable field/property uninitialized
or "CS8618" // Non-nullable field/property uninitialized
or "CS8629"; // Nullable value type may be null
}

private IdentifierNameSyntax? GetIdentifierFromNode(SyntaxNode node)
private ExpressionSyntax? GetTargetExpression(SyntaxNode node)
{
// The warning might be on the identifier itself or a parent node
// The warning might be on the identifier itself, a member access, or a parent node
return node switch
{
IdentifierNameSyntax identifier => identifier,
MemberAccessExpressionSyntax { Expression: IdentifierNameSyntax identifier } => identifier,
ArgumentSyntax { Expression: IdentifierNameSyntax identifier } => identifier,
_ => node.DescendantNodesAndSelf().OfType<IdentifierNameSyntax>().FirstOrDefault()
MemberAccessExpressionSyntax memberAccess => memberAccess,
ArgumentSyntax { Expression: var expression } => expression,
_ => node.DescendantNodesAndSelf()
.OfType<ExpressionSyntax>()
.FirstOrDefault(e => e is IdentifierNameSyntax or MemberAccessExpressionSyntax)
};
}

private bool WasAssertedNotNull(
IdentifierNameSyntax identifierName,
ExpressionSyntax targetExpression,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
var symbol = semanticModel.GetSymbolInfo(identifierName, cancellationToken).Symbol;
if (symbol is null)
{
return false;
}

// Find the containing method/block
var containingMethod = identifierName.FirstAncestorOrSelf<MethodDeclarationSyntax>();
var containingMethod = targetExpression.FirstAncestorOrSelf<MethodDeclarationSyntax>();
if (containingMethod is null)
{
return false;
}

// Look for Assert.That(variable).IsNotNull() patterns before this usage
var allStatements = containingMethod.DescendantNodes().OfType<StatementSyntax>().ToList();
var identifierStatement = identifierName.FirstAncestorOrSelf<StatementSyntax>();
var identifierStatement = targetExpression.FirstAncestorOrSelf<StatementSyntax>();

if (identifierStatement is null)
{
Expand All @@ -117,7 +114,7 @@ private bool WasAssertedNotNull(
var statement = allStatements[i];

// Look for await Assert.That(x).IsNotNull() pattern
if (IsNotNullAssertion(statement, symbol, semanticModel, cancellationToken))
if (IsNotNullAssertion(statement, targetExpression, semanticModel, cancellationToken))
{
return true;
}
Expand All @@ -128,7 +125,7 @@ private bool WasAssertedNotNull(

private bool IsNotNullAssertion(
StatementSyntax statement,
ISymbol targetSymbol,
ExpressionSyntax targetExpression,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -161,11 +158,8 @@ private bool IsNotNullAssertion(

var argument = assertThatCall.ArgumentList.Arguments[0].Expression;

// Get the symbol of the argument
var argumentSymbol = semanticModel.GetSymbolInfo(argument, cancellationToken).Symbol;

// Check if it's the same symbol we're looking for
if (SymbolEqualityComparer.Default.Equals(argumentSymbol, targetSymbol))
// Check if the argument matches the target expression
if (ExpressionsMatch(argument, targetExpression, semanticModel, cancellationToken))
{
return true;
}
Expand All @@ -174,6 +168,40 @@ private bool IsNotNullAssertion(
return false;
}

private bool ExpressionsMatch(
ExpressionSyntax assertArgument,
ExpressionSyntax targetExpression,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
// For simple identifiers, compare using semantic symbols (handles renames, etc.)
if (assertArgument is IdentifierNameSyntax && targetExpression is IdentifierNameSyntax)
{
return SymbolsMatch(assertArgument, targetExpression, semanticModel, cancellationToken);
}

// For member access chains (e.g., value.Id), recursively compare member and receiver
if (assertArgument is MemberAccessExpressionSyntax assertMember &&
targetExpression is MemberAccessExpressionSyntax targetMember)
{
return SymbolsMatch(assertMember, targetMember, semanticModel, cancellationToken) &&
ExpressionsMatch(assertMember.Expression, targetMember.Expression, semanticModel, cancellationToken);
}

return false;
}

private bool SymbolsMatch(
ExpressionSyntax expr1,
ExpressionSyntax expr2,
SemanticModel semanticModel,
CancellationToken cancellationToken)
{
var symbol1 = semanticModel.GetSymbolInfo(expr1, cancellationToken).Symbol;
var symbol2 = semanticModel.GetSymbolInfo(expr2, cancellationToken).Symbol;
return symbol1 is not null && SymbolEqualityComparer.Default.Equals(symbol1, symbol2);
}

private InvocationExpressionSyntax? FindAssertThatInChain(InvocationExpressionSyntax invocation)
{
// Walk up the expression chain looking for Assert.That()
Expand Down Expand Up @@ -230,7 +258,8 @@ private void Suppress(SuppressionAnalysisContext context, Diagnostic diagnostic)
CreateDescriptor("CS8600"),
CreateDescriptor("CS8602"),
CreateDescriptor("CS8604"),
CreateDescriptor("CS8618")
CreateDescriptor("CS8618"),
CreateDescriptor("CS8629")
);

private static SuppressionDescriptor CreateDescriptor(string id)
Expand Down
Loading