diff --git a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs index 42b5de9480..c749977f9f 100644 --- a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs +++ b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs @@ -384,6 +384,7 @@ private static CompilationUnitSyntax AddTwoPhaseFailureTodoComments( /// /// Removes excessive blank lines at the start of class members (after opening brace). /// This can occur after removing members like ITestOutputHelper fields/properties. + /// Preserves preprocessor directives (#region, #if, #endif, etc.) and associated comments. /// protected static CompilationUnitSyntax CleanupClassMemberLeadingTrivia(CompilationUnitSyntax root) { @@ -398,12 +399,63 @@ protected static CompilationUnitSyntax CleanupClassMemberLeadingTrivia(Compilati var firstMember = classToFix.Members.First(); var leadingTrivia = firstMember.GetLeadingTrivia(); - // Keep only indentation (whitespace), remove all newlines - var triviaToKeep = leadingTrivia - .Where(t => !t.IsKind(SyntaxKind.EndOfLineTrivia)) - .Where(t => t.IsKind(SyntaxKind.WhitespaceTrivia) || - (!t.IsKind(SyntaxKind.WhitespaceTrivia) && !t.IsKind(SyntaxKind.EndOfLineTrivia))) - .ToList(); + // Build the new trivia list, preserving preprocessor directives and their context + var triviaToKeep = new List(); + var consecutiveNewlines = 0; + var lastWasPreprocessorOrComment = false; + + foreach (var trivia in leadingTrivia) + { + // Always preserve preprocessor directives + if (IsPreprocessorDirective(trivia)) + { + triviaToKeep.Add(trivia); + consecutiveNewlines = 0; + lastWasPreprocessorOrComment = true; + continue; + } + + // Always preserve comments + if (trivia.IsKind(SyntaxKind.SingleLineCommentTrivia) || + trivia.IsKind(SyntaxKind.MultiLineCommentTrivia) || + trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) || + trivia.IsKind(SyntaxKind.MultiLineDocumentationCommentTrivia)) + { + triviaToKeep.Add(trivia); + consecutiveNewlines = 0; + lastWasPreprocessorOrComment = true; + continue; + } + + // Preserve whitespace (indentation) + if (trivia.IsKind(SyntaxKind.WhitespaceTrivia)) + { + triviaToKeep.Add(trivia); + continue; + } + + // Handle newlines: allow one newline after preprocessor/comment, remove excessive newlines + if (trivia.IsKind(SyntaxKind.EndOfLineTrivia)) + { + consecutiveNewlines++; + + // Keep the first newline after a preprocessor directive or comment + // This ensures proper formatting like: + // #region Test + // [Test] + if (lastWasPreprocessorOrComment && consecutiveNewlines == 1) + { + triviaToKeep.Add(trivia); + } + // Otherwise skip excessive newlines at the start + lastWasPreprocessorOrComment = false; + continue; + } + + // Preserve any other trivia (structured trivia, etc.) + triviaToKeep.Add(trivia); + lastWasPreprocessorOrComment = false; + } var newFirstMember = firstMember.WithLeadingTrivia(triviaToKeep); var updatedClass = classToFix.ReplaceNode(firstMember, newFirstMember); @@ -413,9 +465,29 @@ protected static CompilationUnitSyntax CleanupClassMemberLeadingTrivia(Compilati return currentRoot; } + /// + /// Checks if a trivia is a preprocessor directive. + /// + private static bool IsPreprocessorDirective(SyntaxTrivia trivia) + { + return trivia.IsKind(SyntaxKind.RegionDirectiveTrivia) || + trivia.IsKind(SyntaxKind.EndRegionDirectiveTrivia) || + trivia.IsKind(SyntaxKind.IfDirectiveTrivia) || + trivia.IsKind(SyntaxKind.ElseDirectiveTrivia) || + trivia.IsKind(SyntaxKind.ElifDirectiveTrivia) || + trivia.IsKind(SyntaxKind.EndIfDirectiveTrivia) || + trivia.IsKind(SyntaxKind.DefineDirectiveTrivia) || + trivia.IsKind(SyntaxKind.UndefDirectiveTrivia) || + trivia.IsKind(SyntaxKind.PragmaWarningDirectiveTrivia) || + trivia.IsKind(SyntaxKind.PragmaChecksumDirectiveTrivia) || + trivia.IsKind(SyntaxKind.NullableDirectiveTrivia); + } + /// /// Finds a class with excessive leading trivia on its first member. /// Returns null if no such class exists. + /// Only considers trivia "excessive" if there are multiple consecutive newlines + /// without preprocessor directives or comments between them. /// private static ClassDeclarationSyntax? FindClassWithExcessiveLeadingTrivia(CompilationUnitSyntax root) { @@ -425,7 +497,27 @@ protected static CompilationUnitSyntax CleanupClassMemberLeadingTrivia(Compilati .FirstOrDefault(c => { var leadingTrivia = c.Members.First().GetLeadingTrivia(); - return leadingTrivia.Any(t => t.IsKind(SyntaxKind.EndOfLineTrivia)); + + // Check for excessive newlines (more than one consecutive newline without meaningful trivia) + var consecutiveNewlines = 0; + foreach (var trivia in leadingTrivia) + { + if (trivia.IsKind(SyntaxKind.EndOfLineTrivia)) + { + consecutiveNewlines++; + if (consecutiveNewlines > 1) + { + return true; // Excessive newlines found + } + } + else if (!trivia.IsKind(SyntaxKind.WhitespaceTrivia)) + { + // Non-whitespace, non-newline trivia resets the counter + consecutiveNewlines = 0; + } + } + + return false; }); } diff --git a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs index ecef844479..9dd2aa3d5b 100644 --- a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs +++ b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs @@ -345,6 +345,17 @@ public class MethodSignatureChange : ConversionTarget /// Whether to make the method public (for lifecycle methods) /// public bool MakePublic { get; init; } + + /// + /// Whether to wrap the return type in Task<T> (for non-void, non-Task return types) + /// + public bool WrapReturnTypeInTask { get; init; } + + /// + /// The original return type to wrap (e.g., "object", "int") + /// Only set when WrapReturnTypeInTask is true. + /// + public string? OriginalReturnType { get; init; } } /// diff --git a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs index 08e323bbdf..7ddbc48bad 100644 --- a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs +++ b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs @@ -521,10 +521,17 @@ protected virtual CompilationUnitSyntax AnalyzeMethodSignatures(CompilationUnitS try { + var returnTypeText = method.ReturnType.ToString(); + + // Determine what return type change is needed + var (changeReturnTypeToTask, wrapReturnTypeInTask, originalReturnType) = AnalyzeReturnTypeForAsync(returnTypeText); + var change = new MethodSignatureChange { AddAsync = true, - ChangeReturnTypeToTask = method.ReturnType.ToString() == "void", + ChangeReturnTypeToTask = changeReturnTypeToTask, + WrapReturnTypeInTask = wrapReturnTypeInTask, + OriginalReturnType = originalReturnType, OriginalText = $"{method.ReturnType} {method.Identifier}" }; @@ -548,6 +555,44 @@ protected virtual CompilationUnitSyntax AnalyzeMethodSignatures(CompilationUnitS return currentRoot; } + /// + /// Analyzes the return type to determine what changes are needed for async conversion. + /// + /// + /// A tuple of (changeReturnTypeToTask, wrapReturnTypeInTask, originalReturnType): + /// - changeReturnTypeToTask: true if return type is void and should become Task + /// - wrapReturnTypeInTask: true if return type is non-void, non-Task and should become Task<T> + /// - originalReturnType: the original return type to wrap (only set when wrapReturnTypeInTask is true) + /// + private static (bool changeReturnTypeToTask, bool wrapReturnTypeInTask, string? originalReturnType) AnalyzeReturnTypeForAsync(string returnTypeText) + { + // void → Task + if (returnTypeText == "void") + { + return (true, false, null); + } + + // Already Task or Task → no change needed + if (returnTypeText == "Task" || + returnTypeText.StartsWith("Task<") || + returnTypeText.StartsWith("System.Threading.Tasks.Task")) + { + return (false, false, null); + } + + // Already ValueTask or ValueTask → no change needed (async already works with ValueTask) + if (returnTypeText == "ValueTask" || + returnTypeText.StartsWith("ValueTask<") || + returnTypeText.StartsWith("System.Threading.Tasks.ValueTask")) + { + return (false, false, null); + } + + // Non-void, non-Task return type → wrap in Task + // e.g., object → Task, int → Task + return (false, true, returnTypeText); + } + /// /// Determines which usings to add and remove. /// diff --git a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs index ed1f7a42ed..d9bb94f30c 100644 --- a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs +++ b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs @@ -418,6 +418,19 @@ private CompilationUnitSyntax TransformMethodSignatures(CompilationUnitSyntax ro newMethod = newMethod.WithReturnType(taskType); } + // Wrap return type in Task if needed (non-void, non-Task return type) + if (change.WrapReturnTypeInTask && !string.IsNullOrEmpty(change.OriginalReturnType)) + { + // Build Task + var taskGenericType = SyntaxFactory.GenericName( + SyntaxFactory.Identifier("Task"), + SyntaxFactory.TypeArgumentList( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.ParseTypeName(change.OriginalReturnType)))) + .WithTrailingTrivia(SyntaxFactory.Space); + newMethod = newMethod.WithReturnType(taskGenericType); + } + // Change ValueTask to Task if needed (for IAsyncLifetime.InitializeAsync → IAsyncInitializer) if (change.ChangeValueTaskToTask && method.ReturnType.ToString() == "ValueTask") { diff --git a/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs b/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs index 5feb1de6af..4bcfe745de 100644 --- a/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs +++ b/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs @@ -37,14 +37,15 @@ public class NUnitTwoPhaseAnalyzer : MigrationAnalyzer "Parallelizable", "NonParallelizable", "Repeat", "Values", "Range", "ValueSource", "Sequential", "Combinatorial", "Platform", - "ExpectedException" + "ExpectedException", "FixtureLifeCycle" }; private static readonly HashSet NUnitRemovableAttributeNames = new() { "TestFixture", // TestFixture is implicit in TUnit "Combinatorial", // TUnit's default behavior is combinatorial - "Sequential" // No direct equivalent - TUnit uses Matrix which is combinatorial by default + "Sequential", // No direct equivalent - TUnit uses Matrix which is combinatorial by default + "FixtureLifeCycle" // TUnit creates new instances by default (like InstancePerTestCase) }; private static readonly HashSet NUnitConditionallyRemovableAttributes = new() @@ -1411,7 +1412,8 @@ protected override bool ShouldRemoveAttribute(AttributeSyntax node) "Platform" => ConvertPlatformAttribute(node), "Apartment" => ConvertApartmentAttribute(node), "ExpectedException" => (null, null), // Handled separately - "Sequential" => (null, null), // No direct equivalent - TODO needed + "Sequential" => (null, null), // No direct equivalent - removed + "FixtureLifeCycle" => (null, null), // TUnit uses instance-per-test by default - removed _ => (null, null) }; diff --git a/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs b/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs index 04d4b85592..d0294f1db9 100644 --- a/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs +++ b/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs @@ -329,14 +329,44 @@ private bool IsXUnitAssertion(InvocationExpressionSyntax invocation) { // TUnit has Assert.Throws(action) with the same API as xUnit - no conversion needed! // TUnit's Assert.Throws returns TException just like xUnit. - return (AssertionConversionKind.Throws, null, false, null); + // We return the original expression to track it in the plan without modifying it. + if (args.Count < 1) return (AssertionConversionKind.Throws, null, false, null); + + var action = args[0].Expression.ToString(); + + // Get the type argument from the generic method + if (memberAccess.Name is GenericNameSyntax genericName && + genericName.TypeArgumentList.Arguments.Count > 0) + { + var exceptionType = genericName.TypeArgumentList.Arguments[0].ToString(); + // Return the same expression to track it (no actual change) + return (AssertionConversionKind.Throws, $"Assert.Throws<{exceptionType}>({action})", false, null); + } + + // Non-generic Throws + return (AssertionConversionKind.Throws, $"Assert.Throws({action})", false, null); } private (AssertionConversionKind, string?, bool, string?) ConvertThrowsAsync(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) { // TUnit has Assert.ThrowsAsync(action) with similar API to xUnit - no conversion needed! // TUnit's Assert.ThrowsAsync returns ThrowsAssertion which is awaitable like xUnit's Task. - return (AssertionConversionKind.ThrowsAsync, null, false, null); + // We return the original expression to track it in the plan without modifying it. + if (args.Count < 1) return (AssertionConversionKind.ThrowsAsync, null, false, null); + + var action = args[0].Expression.ToString(); + + // Get the type argument from the generic method + if (memberAccess.Name is GenericNameSyntax genericName && + genericName.TypeArgumentList.Arguments.Count > 0) + { + var exceptionType = genericName.TypeArgumentList.Arguments[0].ToString(); + // ThrowsAsync is already awaited in xUnit, keep the await + return (AssertionConversionKind.ThrowsAsync, $"await Assert.ThrowsAsync<{exceptionType}>({action})", false, null); + } + + // Non-generic ThrowsAsync + return (AssertionConversionKind.ThrowsAsync, $"await Assert.ThrowsAsync({action})", false, null); } private (AssertionConversionKind, string?, bool, string?) ConvertThrowsAny(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) diff --git a/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs b/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs index 9e6025db33..0a2ae89fdf 100644 --- a/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs +++ b/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs @@ -189,6 +189,7 @@ public static CompilationUnitSyntax RemoveFrameworkUsings(CompilationUnitSyntax // Preserve leading trivia from the first using directive (may contain license header) var leadingTrivia = compilationUnit.Usings.FirstOrDefault()?.GetLeadingTrivia() ?? default; + // Step 1: Remove file-level usings var usingsToKeep = compilationUnit.Usings .Where(u => { @@ -220,8 +221,83 @@ public static CompilationUnitSyntax RemoveFrameworkUsings(CompilationUnitSyntax } } + // Step 2: Remove usings inside namespace blocks + result = RemoveUsingsFromNamespaces(result, namespacesToRemove); + return result; } + + /// + /// Removes framework usings from inside namespace blocks (both classic and file-scoped namespaces). + /// + private static CompilationUnitSyntax RemoveUsingsFromNamespaces(CompilationUnitSyntax compilationUnit, string[] namespacesToRemove) + { + var currentRoot = compilationUnit; + + // Handle classic namespace declarations (namespace Foo { ... }) + var namespaceDeclarations = currentRoot.DescendantNodes() + .OfType() + .ToList(); + + foreach (var ns in namespaceDeclarations) + { + if (ns.Usings.Count == 0) continue; + + var currentNs = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == ns.Span); + + if (currentNs == null) continue; + + var namespaceScopedUsingsToKeep = currentNs.Usings + .Where(u => + { + var nameString = u.Name?.ToString() ?? ""; + return !namespacesToRemove.Any(nsToRemove => + nameString == nsToRemove || nameString.StartsWith(nsToRemove + ".")); + }) + .ToList(); + + if (namespaceScopedUsingsToKeep.Count != currentNs.Usings.Count) + { + var updatedNs = currentNs.WithUsings(SyntaxFactory.List(namespaceScopedUsingsToKeep)); + currentRoot = currentRoot.ReplaceNode(currentNs, updatedNs); + } + } + + // Handle file-scoped namespace declarations (namespace Foo;) + var fileScopedNamespaces = currentRoot.DescendantNodes() + .OfType() + .ToList(); + + foreach (var ns in fileScopedNamespaces) + { + if (ns.Usings.Count == 0) continue; + + var currentNs = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == ns.Span); + + if (currentNs == null) continue; + + var namespaceScopedUsingsToKeep = currentNs.Usings + .Where(u => + { + var nameString = u.Name?.ToString() ?? ""; + return !namespacesToRemove.Any(nsToRemove => + nameString == nsToRemove || nameString.StartsWith(nsToRemove + ".")); + }) + .ToList(); + + if (namespaceScopedUsingsToKeep.Count != currentNs.Usings.Count) + { + var updatedNs = currentNs.WithUsings(SyntaxFactory.List(namespaceScopedUsingsToKeep)); + currentRoot = currentRoot.ReplaceNode(currentNs, updatedNs); + } + } + + return currentRoot; + } /// /// Adds System.Threading.Tasks using directive if the code contains async methods or await expressions.