Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ protected ExpressionSyntax CreateTUnitAssertion(
string methodName,
ExpressionSyntax actualValue,
params ArgumentSyntax[] additionalArguments)
{
return CreateTUnitAssertionWithMessage(methodName, actualValue, null, additionalArguments);
}

protected ExpressionSyntax CreateTUnitAssertionWithMessage(
string methodName,
ExpressionSyntax actualValue,
ExpressionSyntax? message,
params ArgumentSyntax[] additionalArguments)
{
// Create Assert.That(actualValue)
var assertThatInvocation = SyntaxFactory.InvocationExpression(
Expand All @@ -60,11 +69,140 @@ protected ExpressionSyntax CreateTUnitAssertion(
? SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(additionalArguments))
: SyntaxFactory.ArgumentList();

var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, arguments);
ExpressionSyntax fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, arguments);

// Add .Because(message) if message is provided and non-empty
if (message != null && !IsEmptyOrNullMessage(message))
{
var becauseAccess = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
fullInvocation,
SyntaxFactory.IdentifierName("Because")
);

// Now wrap the entire thing in await: await Assert.That(actualValue).MethodName(args)
fullInvocation = SyntaxFactory.InvocationExpression(
becauseAccess,
SyntaxFactory.ArgumentList(
SyntaxFactory.SingletonSeparatedList(
SyntaxFactory.Argument(message)
)
)
);
}

// Now wrap the entire thing in await: await Assert.That(actualValue).MethodName(args).Because(message)
return SyntaxFactory.AwaitExpression(fullInvocation);
}

private static bool IsEmptyOrNullMessage(ExpressionSyntax message)
{
// Check for null literal
if (message is LiteralExpressionSyntax literal)
{
if (literal.IsKind(SyntaxKind.NullLiteralExpression))
{
return true;
}

// Check for empty string literal
if (literal.IsKind(SyntaxKind.StringLiteralExpression) &&
literal.Token.ValueText == "")
{
return true;
}
}

return false;
}

/// <summary>
/// Extracts the message and any format arguments from an argument list.
/// Format string messages like Assert.AreEqual(5, x, "Expected {0}", x) have args after the message.
/// </summary>
protected static (ExpressionSyntax? message, ArgumentSyntax[]? formatArgs) ExtractMessageWithFormatArgs(
SeparatedSyntaxList<ArgumentSyntax> arguments,
int messageIndex)
{
if (arguments.Count <= messageIndex)
{
return (null, null);
}

var message = arguments[messageIndex].Expression;

// Check if there are additional format arguments after the message
if (arguments.Count > messageIndex + 1)
{
var formatArgs = arguments.Skip(messageIndex + 1).ToArray();
return (message, formatArgs);
}

Comment on lines +133 to +139
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ExtractMessageWithFormatArgs method doesn't validate that format arguments are actually format placeholders before attempting to wrap them in string.Format. If a testing framework allows passing extra parameters that aren't format args (like tolerance values or comparers), this could incorrectly wrap them in string.Format. Consider checking if the message expression is actually a string literal with format placeholders before treating additional arguments as format args.

Suggested change
// Check if there are additional format arguments after the message
if (arguments.Count > messageIndex + 1)
{
var formatArgs = arguments.Skip(messageIndex + 1).ToArray();
return (message, formatArgs);
}
// If there are no additional arguments after the message, there's nothing to format.
if (arguments.Count <= messageIndex + 1)
{
return (message, null);
}
// Only treat trailing arguments as format arguments when the message is a string
// literal that actually contains format placeholders like "{0}", "{1}", etc.
if (message is LiteralExpressionSyntax literal &&
literal.IsKind(SyntaxKind.StringLiteralExpression))
{
var text = literal.Token.ValueText;
var hasPlaceholder = false;
for (var i = 0; i < text.Length - 1; i++)
{
if (text[i] == '{' && char.IsDigit(text[i + 1]))
{
hasPlaceholder = true;
break;
}
}
if (hasPlaceholder)
{
var formatArgs = arguments.Skip(messageIndex + 1).ToArray();
return (message, formatArgs);
}
}
// Message is not a format string; ignore trailing arguments for formatting purposes.

Copilot uses AI. Check for mistakes.
return (message, null);
}

/// <summary>
/// Creates a message expression, wrapping in string.Format if format args are present.
/// </summary>
protected static ExpressionSyntax CreateMessageExpression(
ExpressionSyntax message,
ArgumentSyntax[]? formatArgs)
{
if (formatArgs == null || formatArgs.Length == 0)
{
return message;
}

// Create string.Format(message, arg1, arg2, ...)
var allArgs = new List<ArgumentSyntax>
{
SyntaxFactory.Argument(message)
};
allArgs.AddRange(formatArgs);

return SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.StringKeyword)),
SyntaxFactory.IdentifierName("Format")
),
SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(allArgs))
);
}

/// <summary>
/// Checks if the argument at the given index appears to be a comparer (IComparer, IEqualityComparer).
/// </summary>
protected bool IsLikelyComparerArgument(ArgumentSyntax argument)
{
var typeInfo = SemanticModel.GetTypeInfo(argument.Expression);
if (typeInfo.Type == null) return false;

var typeName = typeInfo.Type.ToDisplayString();

// Check for IComparer, IComparer<T>, IEqualityComparer, IEqualityComparer<T>
if (typeName.Contains("IComparer") || typeName.Contains("IEqualityComparer"))
{
return true;
}

// Check interfaces
if (typeInfo.Type is INamedTypeSymbol namedType)
{
return namedType.AllInterfaces.Any(i =>
i.Name == "IComparer" ||
i.Name == "IEqualityComparer");
}

return false;
}

/// <summary>
/// Creates a TODO comment for unsupported features during migration.
/// </summary>
protected static SyntaxTrivia CreateTodoComment(string message)
{
return SyntaxFactory.Comment($"// TODO: TUnit migration - {message}");
}

protected bool IsFrameworkAssertion(InvocationExpressionSyntax invocation)
{
Expand Down
85 changes: 85 additions & 0 deletions TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace TUnit.Analyzers.CodeFixers.Base;

/// <summary>
/// Transforms method signatures that contain await expressions but are not marked as async.
/// Converts void methods to async Task and T-returning methods to async Task&lt;T&gt;.
/// </summary>
public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter
{
public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
{
// First, visit children to ensure nested content is processed
node = (MethodDeclarationSyntax)base.VisitMethodDeclaration(node)!;

// Skip if already async or abstract
if (node.Modifiers.Any(SyntaxKind.AsyncKeyword) ||
node.Modifiers.Any(SyntaxKind.AbstractKeyword))
{
return node;
}

// Check if method contains await expressions
bool hasAwait = node.DescendantNodes().OfType<AwaitExpressionSyntax>().Any();
if (!hasAwait)
{
return node;
}

// Convert the return type
var newReturnType = ConvertReturnType(node.ReturnType);

// Add async modifier after access modifiers but before other modifiers (like static)
var newModifiers = InsertAsyncModifier(node.Modifiers);

return node
.WithReturnType(newReturnType)
.WithModifiers(newModifiers);
}

private static TypeSyntax ConvertReturnType(TypeSyntax returnType)
{
// void -> Task
if (returnType is PredefinedTypeSyntax predefined && predefined.Keyword.IsKind(SyntaxKind.VoidKeyword))
{
return SyntaxFactory.ParseTypeName("Task")
.WithLeadingTrivia(returnType.GetLeadingTrivia())
.WithTrailingTrivia(returnType.GetTrailingTrivia());
}

// T -> Task<T>
var innerType = returnType.WithoutTrivia();
return SyntaxFactory.GenericName("Task")
.WithTypeArgumentList(
SyntaxFactory.TypeArgumentList(
SyntaxFactory.SingletonSeparatedList(innerType)))
.WithLeadingTrivia(returnType.GetLeadingTrivia())
.WithTrailingTrivia(returnType.GetTrailingTrivia());
Comment on lines +53 to +60
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conversion logic incorrectly wraps any non-void return type in Task, which will break methods that already return Task or Task<T>. For example, a method returning Task<int> would become async Task<Task<int>>. The method should check if the return type is already a Task or Task<T> before wrapping it.

Copilot uses AI. Check for mistakes.
}

private static SyntaxTokenList InsertAsyncModifier(SyntaxTokenList modifiers)
{
// Find the right position for async (after public/private/etc, before static/virtual/etc)
int insertIndex = 0;

for (int i = 0; i < modifiers.Count; i++)
{
var modifier = modifiers[i];
if (modifier.IsKind(SyntaxKind.PublicKeyword) ||
modifier.IsKind(SyntaxKind.PrivateKeyword) ||
modifier.IsKind(SyntaxKind.ProtectedKeyword) ||
modifier.IsKind(SyntaxKind.InternalKeyword))
{
insertIndex = i + 1;
}
}

var asyncModifier = SyntaxFactory.Token(SyntaxKind.AsyncKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);

return modifiers.Insert(insertIndex, asyncModifier);
}
}
18 changes: 18 additions & 0 deletions TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ protected async Task<Document> ConvertCodeAsync(Document document, SyntaxNode? r
// Framework-specific conversions (also use semantic model while it still matches)
compilationUnit = ApplyFrameworkSpecificConversions(compilationUnit, semanticModel, compilation);

// Fix method signatures that now contain await but aren't marked async
var asyncSignatureRewriter = new AsyncMethodSignatureRewriter();
compilationUnit = (CompilationUnitSyntax)asyncSignatureRewriter.Visit(compilationUnit);

// Remove unnecessary base classes and interfaces
var baseTypeRewriter = CreateBaseTypeRewriter(semanticModel, compilation);
compilationUnit = (CompilationUnitSyntax)baseTypeRewriter.Visit(compilationUnit);
Expand All @@ -69,6 +73,13 @@ protected async Task<Document> ConvertCodeAsync(Document document, SyntaxNode? r
var attributeRewriter = CreateAttributeRewriter(compilation);
compilationUnit = (CompilationUnitSyntax)attributeRewriter.Visit(compilationUnit);

// Ensure [Test] attribute is present when data attributes exist (NUnit-specific)
if (ShouldEnsureTestAttribute())
{
var testAttributeEnsurer = new TestAttributeEnsurer();
compilationUnit = (CompilationUnitSyntax)testAttributeEnsurer.Visit(compilationUnit);
}

// Remove framework usings and add TUnit usings (do this LAST)
compilationUnit = MigrationHelpers.RemoveFrameworkUsings(compilationUnit, FrameworkName);

Expand Down Expand Up @@ -106,6 +117,13 @@ protected async Task<Document> ConvertCodeAsync(Document document, SyntaxNode? r
/// </summary>
protected virtual bool ShouldAddTUnitUsings() => true;

/// <summary>
/// Determines whether to run TestAttributeEnsurer to add [Test] when data attributes exist.
/// Override to return true for NUnit (where [TestCase] alone is valid but TUnit requires [Test] + [Arguments]).
/// Default is false since most frameworks don't need this.
/// </summary>
protected virtual bool ShouldEnsureTestAttribute() => false;

/// <summary>
/// Removes excessive blank lines at the start of class members (after opening brace).
/// This can occur after removing members like ITestOutputHelper fields/properties.
Expand Down
Loading
Loading