diff --git a/TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs b/TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
index 002ea123c9..77da7b16a6 100644
--- a/TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
+++ b/TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
@@ -8,12 +8,37 @@ public abstract class AssertionRewriter : CSharpSyntaxRewriter
{
protected readonly SemanticModel SemanticModel;
protected abstract string FrameworkName { get; }
-
+
+ ///
+ /// Tracks whether the current method has ref, out, or in parameters.
+ /// Methods with these parameters cannot be async, so assertions must use .Wait() instead of await.
+ ///
+ private bool _currentMethodHasRefOutInParameters;
+
protected AssertionRewriter(SemanticModel semanticModel)
{
SemanticModel = semanticModel;
}
-
+
+ public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
+ {
+ // Track whether this method has ref/out/in parameters
+ var previousValue = _currentMethodHasRefOutInParameters;
+ _currentMethodHasRefOutInParameters = node.ParameterList.Parameters.Any(p =>
+ p.Modifiers.Any(SyntaxKind.RefKeyword) ||
+ p.Modifiers.Any(SyntaxKind.OutKeyword) ||
+ p.Modifiers.Any(SyntaxKind.InKeyword));
+
+ try
+ {
+ return base.VisitMethodDeclaration(node);
+ }
+ finally
+ {
+ _currentMethodHasRefOutInParameters = previousValue;
+ }
+ }
+
public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
{
var convertedAssertion = ConvertAssertionIfNeeded(node);
@@ -116,11 +141,8 @@ protected ExpressionSyntax CreateTUnitAssertionWithMessage(
);
}
- // Now wrap the entire thing in await: await Assert.That(actualValue).MethodName(args).Because(message)
- // Need to add a trailing space after 'await' keyword
- var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
- .WithTrailingTrivia(SyntaxFactory.Space);
- return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
+ // Wrap in await or .Wait() depending on whether the method can be async
+ return WrapAssertionForAsync(fullInvocation);
}
///
@@ -181,9 +203,31 @@ protected ExpressionSyntax CreateTUnitGenericAssertion(
);
}
+ // Wrap in await or .Wait() depending on whether the method can be async
+ return WrapAssertionForAsync(fullInvocation);
+ }
+
+ ///
+ /// Wraps an assertion expression in await or .Wait() depending on whether the containing method
+ /// can be async (methods with ref/out/in parameters cannot be async).
+ ///
+ protected ExpressionSyntax WrapAssertionForAsync(ExpressionSyntax assertionExpression)
+ {
+ if (_currentMethodHasRefOutInParameters)
+ {
+ // Method has ref/out/in parameters, cannot be async - use .Wait()
+ var waitAccess = SyntaxFactory.MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ assertionExpression,
+ SyntaxFactory.IdentifierName("Wait")
+ );
+ return SyntaxFactory.InvocationExpression(waitAccess, SyntaxFactory.ArgumentList());
+ }
+
+ // Method can be async - use await
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
- return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
+ return SyntaxFactory.AwaitExpression(awaitKeyword, assertionExpression);
}
protected static bool IsEmptyOrNullMessage(ExpressionSyntax message)
diff --git a/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs b/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs
index 4e6de0c9cf..783a5d0021 100644
--- a/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs
+++ b/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs
@@ -100,6 +100,15 @@ private static string GetMethodKey(MethodDeclarationSyntax node)
return node;
}
+ // Skip if method has ref, in, or out parameters (async methods cannot have these)
+ if (node.ParameterList.Parameters.Any(p =>
+ p.Modifiers.Any(SyntaxKind.RefKeyword) ||
+ p.Modifiers.Any(SyntaxKind.OutKeyword) ||
+ p.Modifiers.Any(SyntaxKind.InKeyword)))
+ {
+ return node;
+ }
+
// Check if method contains await expressions
bool hasAwait = node.DescendantNodes().OfType().Any();
if (!hasAwait)
diff --git a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs
index 3f2c91f9a7..112a15cdda 100644
--- a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs
+++ b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs
@@ -93,6 +93,11 @@ protected async Task ConvertCodeAsync(Document document, SyntaxNode? r
{
compilationUnit = MigrationHelpers.AddTUnitUsings(compilationUnit);
}
+ else
+ {
+ // Even if not adding TUnit usings, always add System.Threading.Tasks if there's async code
+ compilationUnit = MigrationHelpers.AddSystemThreadingTasksUsing(compilationUnit);
+ }
// Clean up trivia issues that can occur after transformations
compilationUnit = CleanupClassMemberLeadingTrivia(compilationUnit);
diff --git a/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs
index 442329335d..798bec8336 100644
--- a/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs
+++ b/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs
@@ -63,10 +63,10 @@ protected override bool IsFrameworkAttribute(string attributeName)
{
return attributeName switch
{
- "Test" or "TestCase" or "TestCaseSource" or
+ "Test" or "Theory" or "TestCase" or "TestCaseSource" or
"SetUp" or "TearDown" or "OneTimeSetUp" or "OneTimeTearDown" or
"TestFixture" or "Category" or "Ignore" or "Explicit" or "Apartment" or
- "Platform" or "Theory" or "Description" => true,
+ "Platform" or "Description" => true,
_ => false
};
}
@@ -846,47 +846,55 @@ private ExpressionSyntax CreateCountAssertion(ExpressionSyntax actualValue, stri
);
}
- // Wrap in await
- var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
- .WithTrailingTrivia(SyntaxFactory.Space);
- return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
+ // Wrap in await or .Wait() depending on whether the method can be async
+ return WrapAssertionForAsync(fullInvocation);
}
///
- /// Chains a method call onto an existing await expression.
+ /// Chains a method call onto an existing await expression or .Wait() expression.
/// For example: await Assert.That(x).IsEqualTo(5) becomes await Assert.That(x).IsEqualTo(5).Within(2)
///
private ExpressionSyntax ChainMethodCall(ExpressionSyntax baseExpression, string methodName, params ArgumentSyntax[] arguments)
{
- // The base expression is an AwaitExpression like: await Assert.That(x).IsEqualTo(5)
- // We need to extract the invocation, add .Within(2) to it, and re-wrap in await
+ ExpressionSyntax innerInvocation;
+
+ // The base expression is either:
+ // 1. An AwaitExpression like: await Assert.That(x).IsEqualTo(5)
+ // 2. An InvocationExpression like: Assert.That(x).IsEqualTo(5).Wait() (for ref/out methods)
if (baseExpression is AwaitExpressionSyntax awaitExpr)
{
- var innerInvocation = awaitExpr.Expression;
-
- // Create the chained method access: Assert.That(x).IsEqualTo(5).Within
- var chainedAccess = SyntaxFactory.MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- innerInvocation,
- SyntaxFactory.IdentifierName(methodName)
- );
-
- // Create the invocation: Assert.That(x).IsEqualTo(5).Within(2)
- var chainedInvocation = SyntaxFactory.InvocationExpression(
- chainedAccess,
- arguments.Length > 0
- ? SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments))
- : SyntaxFactory.ArgumentList()
- );
-
- // Re-wrap in await
- var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
- .WithTrailingTrivia(SyntaxFactory.Space);
- return SyntaxFactory.AwaitExpression(awaitKeyword, chainedInvocation);
+ innerInvocation = awaitExpr.Expression;
}
-
- // Fallback: just return the base expression if it's not the expected shape
- return baseExpression;
+ else if (baseExpression is InvocationExpressionSyntax waitInvocation &&
+ waitInvocation.Expression is MemberAccessExpressionSyntax waitAccess &&
+ waitAccess.Name.Identifier.Text == "Wait")
+ {
+ // Extract the expression before .Wait()
+ innerInvocation = waitAccess.Expression;
+ }
+ else
+ {
+ // Fallback: just return the base expression if it's not the expected shape
+ return baseExpression;
+ }
+
+ // Create the chained method access: Assert.That(x).IsEqualTo(5).Within
+ var chainedAccess = SyntaxFactory.MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ innerInvocation,
+ SyntaxFactory.IdentifierName(methodName)
+ );
+
+ // Create the invocation: Assert.That(x).IsEqualTo(5).Within(2)
+ var chainedInvocation = SyntaxFactory.InvocationExpression(
+ chainedAccess,
+ arguments.Length > 0
+ ? SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments))
+ : SyntaxFactory.ArgumentList()
+ );
+
+ // Re-wrap in await or .Wait() depending on method context
+ return WrapAssertionForAsync(chainedInvocation);
}
private ExpressionSyntax? ConvertClassicAssertion(InvocationExpressionSyntax invocation, string methodName)
@@ -1046,19 +1054,19 @@ private ExpressionSyntax ConvertNUnitThrows(InvocationExpressionSyntax invocatio
)
);
- return SyntaxFactory.AwaitExpression(throwsAsyncInvocation);
+ return WrapAssertionForAsync(throwsAsyncInvocation);
}
-
+
// Handle non-generic constraint-based form: Assert.Throws(constraint, () => ...) or Assert.ThrowsAsync(constraint, () => ...)
// where constraint is typically Is.TypeOf(typeof(T))
if (invocation.ArgumentList.Arguments.Count >= 2)
{
var constraint = invocation.ArgumentList.Arguments[0].Expression;
var action = invocation.ArgumentList.Arguments[1].Expression;
-
+
// Try to extract the exception type from the constraint
var exceptionType = TryExtractTypeFromConstraint(constraint);
-
+
if (exceptionType != null)
{
// Convert to generic ThrowsAsync form: Assert.ThrowsAsync(() => ...)
@@ -1080,7 +1088,7 @@ private ExpressionSyntax ConvertNUnitThrows(InvocationExpressionSyntax invocatio
)
);
- return SyntaxFactory.AwaitExpression(throwsAsyncInvocation);
+ return WrapAssertionForAsync(throwsAsyncInvocation);
}
}
}
@@ -1126,12 +1134,10 @@ private ExpressionSyntax ConvertDoesNotThrow(SeparatedSyntaxList
SyntaxFactory.ArgumentList()
);
- // Wrap in await
- var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
- .WithTrailingTrivia(SyntaxFactory.Space);
- return SyntaxFactory.AwaitExpression(awaitKeyword, throwsNothingInvocation);
+ // Wrap in await or .Wait() depending on method context
+ return WrapAssertionForAsync(throwsNothingInvocation);
}
-
+
///
/// Attempts to extract the exception type from NUnit constraint expressions like Is.TypeOf(typeof(T)).
/// Returns null if the type cannot be extracted.
@@ -1189,7 +1195,7 @@ private ExpressionSyntax CreatePassAssertion(SeparatedSyntaxList
: SyntaxFactory.ArgumentList()
);
- return SyntaxFactory.AwaitExpression(passInvocation);
+ return WrapAssertionForAsync(passInvocation);
}
private ExpressionSyntax CreateFailAssertion(SeparatedSyntaxList arguments)
diff --git a/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs
index b801769d13..d93b47ccc2 100644
--- a/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs
+++ b/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs
@@ -908,7 +908,9 @@ private ExpressionSyntax ConvertThrowsAny(InvocationExpressionSyntax invocation,
)
);
- return SyntaxFactory.AwaitExpression(invocationExpression);
+ var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
+ .WithTrailingTrivia(SyntaxFactory.Space);
+ return SyntaxFactory.AwaitExpression(awaitKeyword, invocationExpression);
}
return CreateTUnitAssertion("Throws", invocation.ArgumentList.Arguments[0].Expression);
@@ -953,7 +955,9 @@ private ExpressionSyntax ConvertIsNotType(InvocationExpressionSyntax invocation,
);
var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList());
- return SyntaxFactory.AwaitExpression(fullInvocation);
+ var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
+ .WithTrailingTrivia(SyntaxFactory.Space);
+ return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}
return CreateTUnitAssertion("IsNotTypeOf", invocation.ArgumentList.Arguments[0].Expression);
@@ -985,7 +989,9 @@ private ExpressionSyntax ConvertThrows(InvocationExpressionSyntax invocation, Si
)
);
- return SyntaxFactory.AwaitExpression(invocationExpression);
+ var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
+ .WithTrailingTrivia(SyntaxFactory.Space);
+ return SyntaxFactory.AwaitExpression(awaitKeyword, invocationExpression);
}
// Fallback
@@ -1018,7 +1024,9 @@ private ExpressionSyntax ConvertThrowsAsync(InvocationExpressionSyntax invocatio
)
);
- return SyntaxFactory.AwaitExpression(invocationExpression);
+ var awaitKeyword2 = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
+ .WithTrailingTrivia(SyntaxFactory.Space);
+ return SyntaxFactory.AwaitExpression(awaitKeyword2, invocationExpression);
}
return CreateTUnitAssertion("ThrowsAsync", invocation.ArgumentList.Arguments[0].Expression);
@@ -1057,7 +1065,9 @@ private ExpressionSyntax ConvertIsType(InvocationExpressionSyntax invocation, Si
);
var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList());
- return SyntaxFactory.AwaitExpression(fullInvocation);
+ var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
+ .WithTrailingTrivia(SyntaxFactory.Space);
+ return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}
return CreateTUnitAssertion("IsTypeOf", invocation.ArgumentList.Arguments[0].Expression);
@@ -1096,7 +1106,9 @@ private ExpressionSyntax ConvertIsAssignableFrom(InvocationExpressionSyntax invo
);
var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList());
- return SyntaxFactory.AwaitExpression(fullInvocation);
+ var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
+ .WithTrailingTrivia(SyntaxFactory.Space);
+ return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}
return CreateTUnitAssertion("IsAssignableTo", invocation.ArgumentList.Arguments[0].Expression);
diff --git a/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs
index f5e5481478..31f5e44941 100644
--- a/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs
+++ b/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs
@@ -2544,6 +2544,110 @@ public void TestMethod()
);
}
+ [Test]
+ public async Task NUnit_Method_With_Ref_Parameter_Not_Converted_To_Async()
+ {
+ // Test that methods with ref parameters use .Wait() instead of await
+ // Since HandleRealized has a ref parameter, it uses .Wait() and doesn't become async
+ // MyTest has no assertions directly, so it doesn't become async either
+ await CodeFixer.VerifyCodeFixAsync(
+ """
+ using NUnit.Framework;
+
+ {|#0:public class MyClass|}
+ {
+ [Test]
+ public void MyTest()
+ {
+ bool realized = false;
+ HandleRealized(this, ref realized);
+ }
+
+ private static void HandleRealized(object sender, ref bool realized)
+ {
+ Assert.That(sender, Is.Not.Null);
+ realized = true;
+ }
+ }
+ """,
+ Verifier.Diagnostic(Rules.NUnitMigration).WithLocation(0),
+ """
+ using TUnit.Core;
+ using TUnit.Assertions;
+ using static TUnit.Assertions.Assert;
+ using TUnit.Assertions.Extensions;
+
+ public class MyClass
+ {
+ [Test]
+ public void MyTest()
+ {
+ bool realized = false;
+ HandleRealized(this, ref realized);
+ }
+
+ private static void HandleRealized(object sender, ref bool realized)
+ {
+ Assert.That(sender).IsNotNull().Wait();
+ realized = true;
+ }
+ }
+ """,
+ ConfigureNUnitTest
+ );
+ }
+
+ [Test]
+ public async Task NUnit_Method_With_Out_Parameter_Not_Converted_To_Async()
+ {
+ await CodeFixer.VerifyCodeFixAsync(
+ """
+ using NUnit.Framework;
+
+ {|#0:public class MyClass|}
+ {
+ [Test]
+ public void MyTest()
+ {
+ TryGetValue("key", out int value);
+ Assert.That(value, Is.EqualTo(42));
+ }
+
+ private static void TryGetValue(string key, out int value)
+ {
+ Assert.That(key, Is.Not.Null);
+ value = 42;
+ }
+ }
+ """,
+ Verifier.Diagnostic(Rules.NUnitMigration).WithLocation(0),
+ """
+ using System.Threading.Tasks;
+ using TUnit.Core;
+ using TUnit.Assertions;
+ using static TUnit.Assertions.Assert;
+ using TUnit.Assertions.Extensions;
+
+ public class MyClass
+ {
+ [Test]
+ public async Task MyTest()
+ {
+ TryGetValue("key", out int value);
+ await Assert.That(value).IsEqualTo(42);
+ }
+
+ private static void TryGetValue(string key, out int value)
+ {
+ Assert.That(key).IsNotNull().Wait();
+ value = 42;
+ }
+ }
+ """,
+ ConfigureNUnitTest
+ );
+ }
+
[Test]
public async Task NUnit_InterfaceImplementation_NotConvertedToAsync()
{
diff --git a/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs
index 427b67e032..fe7362552d 100644
--- a/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs
+++ b/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs
@@ -651,6 +651,7 @@ public void MyTest()
Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0),
"""
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -685,6 +686,7 @@ public void MyTest()
Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0),
"""
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -719,6 +721,7 @@ public void MyTest()
Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0),
"""
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -759,6 +762,7 @@ public void MyTest()
"""
using System;
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -801,6 +805,7 @@ public void MyTest()
using System;
using System.Collections.Generic;
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -844,6 +849,7 @@ public void MyTest()
using System;
using System.Collections.Generic;
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -883,6 +889,7 @@ public void MyTest()
Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0),
"""
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -921,6 +928,7 @@ public void MyTest()
Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0),
"""
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
@@ -960,6 +968,7 @@ public void MyTest()
"""
using System;
using TUnit.Core;
+ using System.Threading.Tasks;
public class MyClass
{
diff --git a/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs b/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs
index cb34fe429d..805b0651cc 100644
--- a/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs
+++ b/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs
@@ -181,17 +181,12 @@ public static CompilationUnitSyntax RemoveFrameworkUsings(CompilationUnitSyntax
return compilationUnit.WithUsings(SyntaxFactory.List(usingsToKeep));
}
- public static CompilationUnitSyntax AddTUnitUsings(CompilationUnitSyntax compilationUnit)
+ ///
+ /// Adds System.Threading.Tasks using directive if the code contains async methods or await expressions.
+ /// This is called unconditionally for all migrations since async methods need the Tasks namespace.
+ ///
+ public static CompilationUnitSyntax AddSystemThreadingTasksUsing(CompilationUnitSyntax compilationUnit)
{
- var tunitUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Core"));
- // Add namespace using so Assert type name is available for Assert.That(...) syntax
- var assertionsNamespaceUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Assertions"));
- var assertionsStaticUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Assertions.Assert"))
- .WithStaticKeyword(SyntaxFactory.Token(SyntaxKind.StaticKeyword));
- var extensionsUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Assertions.Extensions"));
- // Add System.Threading.Tasks for async Task methods
- var tasksUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System.Threading.Tasks"));
-
var existingUsings = compilationUnit.Usings.ToList();
// Add System.Threading.Tasks only if the code has async methods or await expressions
@@ -201,9 +196,28 @@ public static CompilationUnitSyntax AddTUnitUsings(CompilationUnitSyntax compila
if (hasAsyncCode && !existingUsings.Any(u => u.Name?.ToString() == "System.Threading.Tasks"))
{
+ var tasksUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System.Threading.Tasks"));
existingUsings.Add(tasksUsing);
+ return compilationUnit.WithUsings(SyntaxFactory.List(existingUsings));
}
+ return compilationUnit;
+ }
+
+ public static CompilationUnitSyntax AddTUnitUsings(CompilationUnitSyntax compilationUnit)
+ {
+ var tunitUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Core"));
+ // Add namespace using so Assert type name is available for Assert.That(...) syntax
+ var assertionsNamespaceUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Assertions"));
+ var assertionsStaticUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Assertions.Assert"))
+ .WithStaticKeyword(SyntaxFactory.Token(SyntaxKind.StaticKeyword));
+ var extensionsUsing = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("TUnit.Assertions.Extensions"));
+
+ // First add System.Threading.Tasks if needed
+ compilationUnit = AddSystemThreadingTasksUsing(compilationUnit);
+
+ var existingUsings = compilationUnit.Usings.ToList();
+
if (!existingUsings.Any(u => u.Name?.ToString() == "TUnit.Core"))
{
existingUsings.Add(tunitUsing);