Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,84 @@ namespace TUnit.Analyzers.CodeFixers.Base;
/// </summary>
public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter
{
private readonly HashSet<string> _interfaceImplementingMethods;

public AsyncMethodSignatureRewriter() : this(new HashSet<string>())
{
}

public AsyncMethodSignatureRewriter(HashSet<string> interfaceImplementingMethods)
{
_interfaceImplementingMethods = interfaceImplementingMethods;
}

/// <summary>
/// Collects method signatures that implement interface members.
/// This should be called BEFORE syntax modifications while the semantic model is still valid.
/// </summary>
public static HashSet<string> CollectInterfaceImplementingMethods(
CompilationUnitSyntax compilationUnit,
SemanticModel semanticModel)
{
var methods = new HashSet<string>();

foreach (var methodDecl in compilationUnit.DescendantNodes().OfType<MethodDeclarationSyntax>())
{
// Check for explicit interface implementation syntax
if (methodDecl.ExplicitInterfaceSpecifier != null)
{
methods.Add(GetMethodKey(methodDecl));
continue;
}

var methodSymbol = semanticModel.GetDeclaredSymbol(methodDecl);
if (methodSymbol == null)
{
continue;
}

// Check if this method explicitly implements an interface
if (methodSymbol.ExplicitInterfaceImplementations.Length > 0)
{
methods.Add(GetMethodKey(methodDecl));
continue;
}

// Check if this method implicitly implements an interface member
var containingType = methodSymbol.ContainingType;
if (containingType != null)
{
foreach (var iface in containingType.AllInterfaces)
{
foreach (var member in iface.GetMembers().OfType<IMethodSymbol>())
{
var impl = containingType.FindImplementationForInterfaceMember(member);
if (SymbolEqualityComparer.Default.Equals(impl, methodSymbol))
{
methods.Add(GetMethodKey(methodDecl));
break;
}
}
}
}
}

return methods;
}

/// <summary>
/// Gets a unique key for a method declaration based on its signature.
/// This key is stable across syntax tree modifications.
/// </summary>
private static string GetMethodKey(MethodDeclarationSyntax node)
{
// Build a key from class name, method name, and parameter types
var className = node.Ancestors().OfType<TypeDeclarationSyntax>().FirstOrDefault()?.Identifier.Text ?? "";
var methodName = node.Identifier.Text;
var parameters = string.Join(",", node.ParameterList.Parameters.Select(p => p.Type?.ToString() ?? ""));
return $"{className}.{methodName}({parameters})";
}

public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
{
// First, visit children to ensure nested content is processed
Expand All @@ -29,6 +107,21 @@ public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter
return node;
}

// Skip methods with ref/out/in parameters (they can't be async)
if (node.ParameterList.Parameters.Any(p =>
p.Modifiers.Any(SyntaxKind.RefKeyword) ||
p.Modifiers.Any(SyntaxKind.OutKeyword) ||
p.Modifiers.Any(SyntaxKind.InKeyword)))
{
return node;
}

// Skip if method implements an interface member (changing return type would break the implementation)
if (ImplementsInterfaceMember(node))
{
return node;
}

// Convert the return type
var newReturnType = ConvertReturnType(node.ReturnType);

Expand All @@ -40,6 +133,19 @@ public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter
.WithModifiers(newModifiers);
}

private bool ImplementsInterfaceMember(MethodDeclarationSyntax node)
{
// Check for explicit interface implementation syntax (IFoo.Method)
if (node.ExplicitInterfaceSpecifier != null)
{
return true;
}

// Check if this method was identified as an interface implementation
var key = GetMethodKey(node);
return _interfaceImplementingMethods.Contains(key);
}

private static TypeSyntax ConvertReturnType(TypeSyntax returnType)
{
// void -> Task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ protected async Task<Document> ConvertCodeAsync(Document document, SyntaxNode? r

try
{
// IMPORTANT: Collect interface-implementing methods BEFORE any syntax modifications
// while the semantic model is still valid for the original syntax tree
var interfaceImplementingMethods = AsyncMethodSignatureRewriter.CollectInterfaceImplementingMethods(
compilationUnit, semanticModel);

// Convert assertions FIRST (while semantic model still matches the syntax tree)
var assertionRewriter = CreateAssertionRewriter(semanticModel, compilation);
compilationUnit = (CompilationUnitSyntax)assertionRewriter.Visit(compilationUnit);
Expand All @@ -58,7 +63,8 @@ protected async Task<Document> ConvertCodeAsync(Document document, SyntaxNode? r
compilationUnit = ApplyFrameworkSpecificConversions(compilationUnit, semanticModel, compilation);

// Fix method signatures that now contain await but aren't marked async
var asyncSignatureRewriter = new AsyncMethodSignatureRewriter();
// Pass the collected interface methods to avoid converting interface implementations
var asyncSignatureRewriter = new AsyncMethodSignatureRewriter(interfaceImplementingMethods);
compilationUnit = (CompilationUnitSyntax)asyncSignatureRewriter.Visit(compilationUnit);

// Remove unnecessary base classes and interfaces
Expand Down
64 changes: 64 additions & 0 deletions TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2474,6 +2474,70 @@ public async Task TestMethod()
);
}

[Test]
public async Task NUnit_InterfaceImplementation_NotConvertedToAsync()
{
// Methods that implement interface members should NOT be converted to async
// because that would break the interface implementation contract.
// The interface method contains no NUnit assertions, so no await is added.
// Only the test method (which doesn't implement an interface) gets converted to async.
await CodeFixer.VerifyCodeFixAsync(
"""
using NUnit.Framework;
using System.Threading.Tasks;

public interface ITestRunner
{
void Run();
}

{|#0:public class MyClass|} : ITestRunner
{
[Test]
public void TestMethod()
{
Assert.That(true, Is.True);
}

public void Run()
{
// This implements ITestRunner.Run() and should stay void
var x = 1;
}
}
""",
Verifier.Diagnostic(Rules.NUnitMigration).WithLocation(0),
"""
using System.Threading.Tasks;
using TUnit.Core;
using TUnit.Assertions;
using static TUnit.Assertions.Assert;
using TUnit.Assertions.Extensions;

public interface ITestRunner
{
void Run();
}

public class MyClass : ITestRunner
{
[Test]
public async Task TestMethod()
{
await Assert.That(true).IsTrue();
}

public void Run()
{
// This implements ITestRunner.Run() and should stay void
var x = 1;
}
}
""",
ConfigureNUnitTest
);
}

private static void ConfigureNUnitTest(Verifier.Test test)
{
test.TestState.AdditionalReferences.Add(typeof(NUnit.Framework.TestAttribute).Assembly);
Expand Down
Loading