Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,37 @@ public abstract class AssertionRewriter : CSharpSyntaxRewriter
{
protected readonly SemanticModel SemanticModel;
protected abstract string FrameworkName { get; }


/// <summary>
/// Tracks whether the current method has ref, out, or in parameters.
/// Methods with these parameters cannot be async, so assertions must use .Wait() instead of await.
/// </summary>
private bool _currentMethodHasRefOutInParameters;

protected AssertionRewriter(SemanticModel semanticModel)
{
SemanticModel = semanticModel;
}


public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
{
// Track whether this method has ref/out/in parameters
var previousValue = _currentMethodHasRefOutInParameters;
_currentMethodHasRefOutInParameters = node.ParameterList.Parameters.Any(p =>
p.Modifiers.Any(SyntaxKind.RefKeyword) ||
p.Modifiers.Any(SyntaxKind.OutKeyword) ||
p.Modifiers.Any(SyntaxKind.InKeyword));

try
{
return base.VisitMethodDeclaration(node);
}
finally
{
_currentMethodHasRefOutInParameters = previousValue;
}
}

public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
{
var convertedAssertion = ConvertAssertionIfNeeded(node);
Expand Down Expand Up @@ -116,11 +141,8 @@ protected ExpressionSyntax CreateTUnitAssertionWithMessage(
);
}

// Now wrap the entire thing in await: await Assert.That(actualValue).MethodName(args).Because(message)
// Need to add a trailing space after 'await' keyword
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
// Wrap in await or .Wait() depending on whether the method can be async
return WrapAssertionForAsync(fullInvocation);
}

/// <summary>
Expand Down Expand Up @@ -181,9 +203,31 @@ protected ExpressionSyntax CreateTUnitGenericAssertion(
);
}

// Wrap in await or .Wait() depending on whether the method can be async
return WrapAssertionForAsync(fullInvocation);
}

/// <summary>
/// Wraps an assertion expression in await or .Wait() depending on whether the containing method
/// can be async (methods with ref/out/in parameters cannot be async).
/// </summary>
protected ExpressionSyntax WrapAssertionForAsync(ExpressionSyntax assertionExpression)
{
if (_currentMethodHasRefOutInParameters)
{
// Method has ref/out/in parameters, cannot be async - use .Wait()
var waitAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
assertionExpression,
SyntaxFactory.IdentifierName("Wait")
);
return SyntaxFactory.InvocationExpression(waitAccess, SyntaxFactory.ArgumentList());
}

// Method can be async - use await
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
return SyntaxFactory.AwaitExpression(awaitKeyword, assertionExpression);
}

protected static bool IsEmptyOrNullMessage(ExpressionSyntax message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ private static string GetMethodKey(MethodDeclarationSyntax node)
return node;
}

// Skip if method has ref, in, or out parameters (async methods cannot have these)
if (node.ParameterList.Parameters.Any(p =>
p.Modifiers.Any(SyntaxKind.RefKeyword) ||
p.Modifiers.Any(SyntaxKind.OutKeyword) ||
p.Modifiers.Any(SyntaxKind.InKeyword)))
{
return node;
}

// Check if method contains await expressions
bool hasAwait = node.DescendantNodes().OfType<AwaitExpressionSyntax>().Any();
if (!hasAwait)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ protected async Task<Document> ConvertCodeAsync(Document document, SyntaxNode? r
{
compilationUnit = MigrationHelpers.AddTUnitUsings(compilationUnit);
}
else
{
// Even if not adding TUnit usings, always add System.Threading.Tasks if there's async code
compilationUnit = MigrationHelpers.AddSystemThreadingTasksUsing(compilationUnit);
}

// Clean up trivia issues that can occur after transformations
compilationUnit = CleanupClassMemberLeadingTrivia(compilationUnit);
Expand Down
94 changes: 50 additions & 44 deletions TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ protected override bool IsFrameworkAttribute(string attributeName)
{
return attributeName switch
{
"Test" or "TestCase" or "TestCaseSource" or
"Test" or "Theory" or "TestCase" or "TestCaseSource" or
"SetUp" or "TearDown" or "OneTimeSetUp" or "OneTimeTearDown" or
"TestFixture" or "Category" or "Ignore" or "Explicit" or "Apartment" or
"Platform" or "Theory" or "Description" => true,
"Platform" or "Description" => true,
_ => false
};
}
Expand Down Expand Up @@ -846,47 +846,55 @@ private ExpressionSyntax CreateCountAssertion(ExpressionSyntax actualValue, stri
);
}

// Wrap in await
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
// Wrap in await or .Wait() depending on whether the method can be async
return WrapAssertionForAsync(fullInvocation);
}

/// <summary>
/// Chains a method call onto an existing await expression.
/// Chains a method call onto an existing await expression or .Wait() expression.
/// For example: await Assert.That(x).IsEqualTo(5) becomes await Assert.That(x).IsEqualTo(5).Within(2)
/// </summary>
private ExpressionSyntax ChainMethodCall(ExpressionSyntax baseExpression, string methodName, params ArgumentSyntax[] arguments)
{
// The base expression is an AwaitExpression like: await Assert.That(x).IsEqualTo(5)
// We need to extract the invocation, add .Within(2) to it, and re-wrap in await
ExpressionSyntax innerInvocation;

// The base expression is either:
// 1. An AwaitExpression like: await Assert.That(x).IsEqualTo(5)
// 2. An InvocationExpression like: Assert.That(x).IsEqualTo(5).Wait() (for ref/out methods)
if (baseExpression is AwaitExpressionSyntax awaitExpr)
{
var innerInvocation = awaitExpr.Expression;

// Create the chained method access: Assert.That(x).IsEqualTo(5).Within
var chainedAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
innerInvocation,
SyntaxFactory.IdentifierName(methodName)
);

// Create the invocation: Assert.That(x).IsEqualTo(5).Within(2)
var chainedInvocation = SyntaxFactory.InvocationExpression(
chainedAccess,
arguments.Length > 0
? SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments))
: SyntaxFactory.ArgumentList()
);

// Re-wrap in await
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, chainedInvocation);
innerInvocation = awaitExpr.Expression;
}

// Fallback: just return the base expression if it's not the expected shape
return baseExpression;
else if (baseExpression is InvocationExpressionSyntax waitInvocation &&
waitInvocation.Expression is MemberAccessExpressionSyntax waitAccess &&
waitAccess.Name.Identifier.Text == "Wait")
{
// Extract the expression before .Wait()
innerInvocation = waitAccess.Expression;
}
else
{
// Fallback: just return the base expression if it's not the expected shape
return baseExpression;
}

// Create the chained method access: Assert.That(x).IsEqualTo(5).Within
var chainedAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
innerInvocation,
SyntaxFactory.IdentifierName(methodName)
);

// Create the invocation: Assert.That(x).IsEqualTo(5).Within(2)
var chainedInvocation = SyntaxFactory.InvocationExpression(
chainedAccess,
arguments.Length > 0
? SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments))
: SyntaxFactory.ArgumentList()
);

// Re-wrap in await or .Wait() depending on method context
return WrapAssertionForAsync(chainedInvocation);
}

private ExpressionSyntax? ConvertClassicAssertion(InvocationExpressionSyntax invocation, string methodName)
Expand Down Expand Up @@ -1046,19 +1054,19 @@ private ExpressionSyntax ConvertNUnitThrows(InvocationExpressionSyntax invocatio
)
);

return SyntaxFactory.AwaitExpression(throwsAsyncInvocation);
return WrapAssertionForAsync(throwsAsyncInvocation);
}

// Handle non-generic constraint-based form: Assert.Throws(constraint, () => ...) or Assert.ThrowsAsync(constraint, () => ...)
// where constraint is typically Is.TypeOf(typeof(T))
if (invocation.ArgumentList.Arguments.Count >= 2)
{
var constraint = invocation.ArgumentList.Arguments[0].Expression;
var action = invocation.ArgumentList.Arguments[1].Expression;

// Try to extract the exception type from the constraint
var exceptionType = TryExtractTypeFromConstraint(constraint);

if (exceptionType != null)
{
// Convert to generic ThrowsAsync form: Assert.ThrowsAsync<T>(() => ...)
Expand All @@ -1080,7 +1088,7 @@ private ExpressionSyntax ConvertNUnitThrows(InvocationExpressionSyntax invocatio
)
);

return SyntaxFactory.AwaitExpression(throwsAsyncInvocation);
return WrapAssertionForAsync(throwsAsyncInvocation);
}
}
}
Expand Down Expand Up @@ -1126,12 +1134,10 @@ private ExpressionSyntax ConvertDoesNotThrow(SeparatedSyntaxList<ArgumentSyntax>
SyntaxFactory.ArgumentList()
);

// Wrap in await
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, throwsNothingInvocation);
// Wrap in await or .Wait() depending on method context
return WrapAssertionForAsync(throwsNothingInvocation);
}

/// <summary>
/// Attempts to extract the exception type from NUnit constraint expressions like Is.TypeOf(typeof(T)).
/// Returns null if the type cannot be extracted.
Expand Down Expand Up @@ -1189,7 +1195,7 @@ private ExpressionSyntax CreatePassAssertion(SeparatedSyntaxList<ArgumentSyntax>
: SyntaxFactory.ArgumentList()
);

return SyntaxFactory.AwaitExpression(passInvocation);
return WrapAssertionForAsync(passInvocation);
}

private ExpressionSyntax CreateFailAssertion(SeparatedSyntaxList<ArgumentSyntax> arguments)
Expand Down
24 changes: 18 additions & 6 deletions TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,9 @@ private ExpressionSyntax ConvertThrowsAny(InvocationExpressionSyntax invocation,
)
);

return SyntaxFactory.AwaitExpression(invocationExpression);
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, invocationExpression);
}

return CreateTUnitAssertion("Throws", invocation.ArgumentList.Arguments[0].Expression);
Expand Down Expand Up @@ -953,7 +955,9 @@ private ExpressionSyntax ConvertIsNotType(InvocationExpressionSyntax invocation,
);

var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList());
return SyntaxFactory.AwaitExpression(fullInvocation);
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}

return CreateTUnitAssertion("IsNotTypeOf", invocation.ArgumentList.Arguments[0].Expression);
Expand Down Expand Up @@ -985,7 +989,9 @@ private ExpressionSyntax ConvertThrows(InvocationExpressionSyntax invocation, Si
)
);

return SyntaxFactory.AwaitExpression(invocationExpression);
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, invocationExpression);
}

// Fallback
Expand Down Expand Up @@ -1018,7 +1024,9 @@ private ExpressionSyntax ConvertThrowsAsync(InvocationExpressionSyntax invocatio
)
);

return SyntaxFactory.AwaitExpression(invocationExpression);
var awaitKeyword2 = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword2, invocationExpression);
}

return CreateTUnitAssertion("ThrowsAsync", invocation.ArgumentList.Arguments[0].Expression);
Expand Down Expand Up @@ -1057,7 +1065,9 @@ private ExpressionSyntax ConvertIsType(InvocationExpressionSyntax invocation, Si
);

var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList());
return SyntaxFactory.AwaitExpression(fullInvocation);
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}

return CreateTUnitAssertion("IsTypeOf", invocation.ArgumentList.Arguments[0].Expression);
Expand Down Expand Up @@ -1096,7 +1106,9 @@ private ExpressionSyntax ConvertIsAssignableFrom(InvocationExpressionSyntax invo
);

var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList());
return SyntaxFactory.AwaitExpression(fullInvocation);
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}

return CreateTUnitAssertion("IsAssignableTo", invocation.ArgumentList.Arguments[0].Expression);
Expand Down
Loading
Loading