Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 138 additions & 12 deletions TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,22 @@ protected AssertionRewriter(SemanticModel semanticModel)

public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
{
var convertedAssertion = ConvertAssertionIfNeeded(node);
// Wrap the conversion in try-catch to ensure one failing assertion doesn't break
// the conversion of all other assertions in the file
ExpressionSyntax? convertedAssertion;
try
{
convertedAssertion = ConvertAssertionIfNeeded(node);
}
catch (Exception ex) when (ex is InvalidOperationException or ArgumentException or NotSupportedException)
{
// If conversion fails for this specific assertion due to expected issues
// (e.g., invalid syntax, unsupported patterns), skip it and continue.
// This ensures partial conversion is better than no conversion.
// Unexpected exceptions will propagate for debugging.
return base.VisitInvocationExpression(node);
}

if (convertedAssertion != null)
{
var conversionTrivia = convertedAssertion.GetLeadingTrivia();
Expand Down Expand Up @@ -89,12 +104,88 @@ protected ExpressionSyntax CreateTUnitAssertion(
return CreateTUnitAssertionWithMessage(methodName, actualValue, null, additionalArguments);
}

/// <summary>
/// Creates a TUnit collection assertion for enumerable/collection types.
/// </summary>
/// <remarks>
/// Note: We intentionally do NOT cast to IEnumerable&lt;T&gt; because:
/// 1. TUnit's Assert.That&lt;T&gt; overloads generally resolve correctly for arrays and collections
/// 2. Adding explicit casts creates noisy code that users would need to clean up
/// 3. If there's genuine overload ambiguity, users can add the cast manually
/// </remarks>
protected ExpressionSyntax CreateTUnitCollectionAssertion(
string methodName,
ExpressionSyntax collectionValue,
params ArgumentSyntax[] additionalArguments)
{
return CreateTUnitAssertionWithMessage(methodName, collectionValue, null, additionalArguments);
}

/// <summary>
/// Ensures that ValueTask and Task types are properly awaited before being passed to Assert.That().
/// This is needed because TUnit's analyzer (TUnitAssertions0008) requires ValueTask to be awaited.
/// If the expression is already an await expression, it's returned as-is.
/// </summary>
private ExpressionSyntax EnsureTaskTypesAreAwaited(ExpressionSyntax expression)
{
// If already an await expression, no action needed
if (expression is AwaitExpressionSyntax)
{
return expression;
}

// Wrap semantic analysis in try-catch to handle TFM-specific failures
// This prevents AggregateException crashes in multi-target projects
try
{
// Try to get the type of the expression using semantic analysis
var typeInfo = SemanticModel.GetTypeInfo(expression);
if (typeInfo.Type is null || typeInfo.Type.TypeKind == TypeKind.Error)
{
return expression;
}

// Check if the type is ValueTask, ValueTask<T>, Task, or Task<T>
var typeName = typeInfo.Type.ToDisplayString();
var isTaskType = typeName.StartsWith("System.Threading.Tasks.ValueTask") ||
typeName.StartsWith("System.Threading.Tasks.Task") ||
typeName == "System.Threading.Tasks.ValueTask" ||
typeName == "System.Threading.Tasks.Task";

// Also check for the short names (when using directive is present)
if (!isTaskType && typeInfo.Type is INamedTypeSymbol namedType)
{
isTaskType = namedType.Name is "ValueTask" or "Task" &&
namedType.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks";
}

if (!isTaskType)
{
return expression;
}

// Wrap the expression in an await
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, expression);
}
catch (Exception ex) when (ex is InvalidOperationException or ArgumentException)
{
// Semantic analysis can fail in some TFM configurations (e.g., type not available
// in one target framework). Return expression unchanged and let the user handle it.
return expression;
}
}

protected ExpressionSyntax CreateTUnitAssertionWithMessage(
string methodName,
ExpressionSyntax actualValue,
ExpressionSyntax? message,
params ArgumentSyntax[] additionalArguments)
{
// Ensure ValueTask/Task types are properly awaited before passing to Assert.That
actualValue = EnsureTaskTypesAreAwaited(actualValue);

// Create Assert.That(actualValue)
var assertThatInvocation = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
Expand Down Expand Up @@ -154,6 +245,9 @@ protected ExpressionSyntax CreateTUnitGenericAssertion(
TypeSyntax typeArg,
ExpressionSyntax? message)
{
// Ensure ValueTask/Task types are properly awaited before passing to Assert.That
actualValue = EnsureTaskTypesAreAwaited(actualValue);

// Create Assert.That(actualValue)
var assertThatInvocation = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
Expand Down Expand Up @@ -414,24 +508,40 @@ protected static SyntaxTrivia CreateTodoComment(string message)

/// <summary>
/// Determines if an invocation is a framework assertion method.
/// Uses semantic analysis when available, with syntax-based fallback for resilience across TFMs.
/// IMPORTANT: Prioritizes syntax-based detection for deterministic results across TFMs.
/// This prevents AggregateException crashes in multi-target projects where semantic
/// analysis could produce different results for each target framework.
/// </summary>
protected bool IsFrameworkAssertion(InvocationExpressionSyntax invocation)
{
// Try semantic analysis first
var symbolInfo = SemanticModel.GetSymbolInfo(invocation);
if (symbolInfo.Symbol is IMethodSymbol methodSymbol)
// FIRST: Try syntax-based detection (deterministic across TFMs)
// This ensures consistent behavior for multi-target projects
if (IsFrameworkAssertionBySyntax(invocation))
{
var namespaceName = methodSymbol.ContainingNamespace?.ToDisplayString() ?? "";
if (IsFrameworkAssertionNamespace(namespaceName))
return true;
}

// SECOND: Fall back to semantic analysis for cases where syntax detection fails
// (e.g., aliased Assert types, extension methods, etc.)
try
{
var symbolInfo = SemanticModel.GetSymbolInfo(invocation);
if (symbolInfo.Symbol is IMethodSymbol methodSymbol)
{
return true;
var namespaceName = methodSymbol.ContainingNamespace?.ToDisplayString() ?? "";
if (IsFrameworkAssertionNamespace(namespaceName))
{
return true;
}
}
}
catch (Exception ex) when (ex is InvalidOperationException or ArgumentException)
{
// Semantic analysis can fail in edge cases (e.g., incomplete compilation state).
// That's fine - we already tried syntax-based detection above.
}

// Fallback: Syntax-based detection when semantic analysis fails
// This ensures consistent behavior across TFMs
return IsFrameworkAssertionBySyntax(invocation);
return false;
}

/// <summary>
Expand All @@ -445,12 +555,28 @@ private bool IsFrameworkAssertionBySyntax(InvocationExpressionSyntax invocation)
return false;
}

var targetType = memberAccess.Expression.ToString();
// Extract the simple type name from potentially qualified names
// e.g., "NUnit.Framework.Assert" -> "Assert", "Assert" -> "Assert"
var targetType = ExtractSimpleTypeName(memberAccess.Expression);
var methodName = memberAccess.Name.Identifier.Text;

return IsKnownAssertionTypeBySyntax(targetType, methodName);
}

/// <summary>
/// Extracts the simple type name from an expression.
/// Handles qualified names like "NUnit.Framework.Assert" by returning just "Assert".
/// </summary>
private static string ExtractSimpleTypeName(ExpressionSyntax expression)
{
return expression switch
{
IdentifierNameSyntax identifier => identifier.Identifier.Text,
MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.Text,
_ => expression.ToString()
};
}

/// <summary>
/// Checks if the target type and method name match known framework assertion patterns.
/// Override in derived classes to provide framework-specific patterns.
Expand Down
36 changes: 26 additions & 10 deletions TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -660,17 +660,15 @@ private UsingStatementSyntax CreateUsingMultipleStatement(ExpressionStatementSyn

protected override bool IsFrameworkAssertionNamespace(string namespaceName)
{
// Exclude NUnit.Framework.Legacy - ClassicAssert should not be converted
return (namespaceName == "NUnit.Framework" || namespaceName.StartsWith("NUnit.Framework."))
&& namespaceName != "NUnit.Framework.Legacy";
// Include NUnit.Framework.Legacy - ClassicAssert should be converted to TUnit assertions
return namespaceName == "NUnit.Framework" || namespaceName.StartsWith("NUnit.Framework.");
}

protected override bool IsKnownAssertionTypeBySyntax(string targetType, string methodName)
{
// NUnit assertion types that can be detected by syntax
// NOTE: ClassicAssert is NOT included because it's in NUnit.Framework.Legacy namespace
// and should not be auto-converted. The semantic check excludes it properly.
return targetType is "Assert" or "CollectionAssert" or "StringAssert" or "FileAssert" or "DirectoryAssert";
// ClassicAssert is in NUnit.Framework.Legacy and should be converted
return targetType is "Assert" or "ClassicAssert" or "CollectionAssert" or "StringAssert" or "FileAssert" or "DirectoryAssert";
}

protected override ExpressionSyntax? ConvertAssertionIfNeeded(InvocationExpressionSyntax invocation)
Expand Down Expand Up @@ -717,15 +715,33 @@ protected override bool IsKnownAssertionTypeBySyntax(string targetType, string m
}

// Handle classic assertions like Assert.AreEqual, ClassicAssert.AreEqual, etc.
if (invocation.Expression is MemberAccessExpressionSyntax classicMemberAccess &&
classicMemberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" or "ClassicAssert" })
// Also handles qualified names like NUnit.Framework.Assert.AreEqual
if (invocation.Expression is MemberAccessExpressionSyntax classicMemberAccess)
{
return ConvertClassicAssertion(invocation, classicMemberAccess.Name.Identifier.Text);
var typeName = GetSimpleTypeName(classicMemberAccess.Expression);
if (typeName is "Assert" or "ClassicAssert")
{
return ConvertClassicAssertion(invocation, classicMemberAccess.Name.Identifier.Text);
}
}

return null;
}


/// <summary>
/// Extracts the simple type name from an expression.
/// Handles both simple identifiers and qualified names like "NUnit.Framework.Assert".
/// </summary>
private static string GetSimpleTypeName(ExpressionSyntax expression)
{
return expression switch
{
IdentifierNameSyntax identifier => identifier.Identifier.Text,
MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.Text,
_ => expression.ToString()
};
}

private ExpressionSyntax ConvertAssertThat(InvocationExpressionSyntax invocation)
{
var arguments = invocation.ArgumentList.Arguments;
Expand Down
Loading
Loading