diff --git a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs index 24b696f0c9..42b5de9480 100644 --- a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs +++ b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs @@ -5,6 +5,7 @@ using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; using TUnit.Analyzers.Migrators.Base; namespace TUnit.Analyzers.CodeFixers.Base; @@ -48,72 +49,317 @@ protected async Task ConvertCodeAsync(Document document, SyntaxNode? r var compilation = semanticModel.Compilation; - try + // Check if this framework supports the new two-phase architecture + var analyzer = CreateTwoPhaseAnalyzer(semanticModel, compilation); + if (analyzer != null) { - // IMPORTANT: Collect interface-implementing methods BEFORE any syntax modifications - // while the semantic model is still valid for the original syntax tree - var interfaceImplementingMethods = AsyncMethodSignatureRewriter.CollectInterfaceImplementingMethods( - compilationUnit, semanticModel); + return await ConvertCodeWithTwoPhaseAsync(document, compilationUnit, analyzer, root); + } - // Convert assertions FIRST (while semantic model still matches the syntax tree) - var assertionRewriter = CreateAssertionRewriter(semanticModel, compilation); - compilationUnit = (CompilationUnitSyntax)assertionRewriter.Visit(compilationUnit); + // Fall back to the legacy rewriter-based approach + return await ConvertCodeWithRewritersAsync(document, compilationUnit, semanticModel, compilation, root); + } - // Framework-specific conversions (also use semantic model while it still matches) - compilationUnit = ApplyFrameworkSpecificConversions(compilationUnit, semanticModel, compilation); + /// + /// Creates a two-phase migration analyzer for this framework. + /// Override in derived classes to enable the two-phase architecture. + /// Returns null to use the legacy rewriter-based approach. + /// + protected virtual MigrationAnalyzer? CreateTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + { + return null; + } - // Fix method signatures that now contain await but aren't marked async - // Pass the collected interface methods to avoid converting interface implementations - var asyncSignatureRewriter = new AsyncMethodSignatureRewriter(interfaceImplementingMethods); - compilationUnit = (CompilationUnitSyntax)asyncSignatureRewriter.Visit(compilationUnit); + /// + /// New two-phase architecture: Analyze first (while semantic model valid), then transform (pure syntax). + /// This avoids semantic model staleness issues that plague the rewriter-based approach. + /// + private async Task ConvertCodeWithTwoPhaseAsync( + Document document, + CompilationUnitSyntax compilationUnit, + MigrationAnalyzer analyzer, + SyntaxNode? originalRoot) + { + // Phase 1: Analyze - collect all conversion targets while semantic model is valid + // Returns annotated root (nodes marked for conversion) and the conversion plan + var (annotatedRoot, plan) = analyzer.Analyze(compilationUnit); + + // Phase 2: Transform - apply conversions using only syntax operations + // Uses annotations to find nodes, no semantic model needed + var transformer = new MigrationTransformer(plan, FrameworkName); + var transformedRoot = transformer.Transform(annotatedRoot); + + // Phase 2.5: Apply framework-specific syntax-only transformations + // These use CSharpSyntaxRewriter but don't need the semantic model + transformedRoot = ApplyTwoPhasePostTransformations(transformedRoot); + + // Phase 2.6: Re-check usings after post-transformations + // Post-transformations like NUnitExpectedResultRewriter may introduce async code + // that wasn't present during the initial usings transformation + transformedRoot = MigrationHelpers.AddTUnitUsings(transformedRoot); + + // Final cleanup (pure syntax operations) + transformedRoot = CleanupClassMemberLeadingTrivia(transformedRoot); + transformedRoot = CleanupEndOfFileTrivia(transformedRoot); + transformedRoot = NormalizeLineEndings(transformedRoot, originalRoot!); + + // Add TODO comments for any failures + if (plan.HasFailures) + { + transformedRoot = AddTwoPhaseFailureTodoComments(transformedRoot, plan); + } - // Remove unnecessary base classes and interfaces - var baseTypeRewriter = CreateBaseTypeRewriter(semanticModel, compilation); - compilationUnit = (CompilationUnitSyntax)baseTypeRewriter.Visit(compilationUnit); + return document.WithSyntaxRoot(transformedRoot); + } - // Update lifecycle methods - var lifecycleRewriter = CreateLifecycleRewriter(compilation); - compilationUnit = (CompilationUnitSyntax)lifecycleRewriter.Visit(compilationUnit); + /// + /// Apply framework-specific post-transformations after the main two-phase processing. + /// These should be pure syntax operations that don't require the semantic model. + /// Override in derived classes to add framework-specific transformations. + /// + protected virtual CompilationUnitSyntax ApplyTwoPhasePostTransformations(CompilationUnitSyntax compilationUnit) + { + return compilationUnit; + } + + /// + /// Legacy rewriter-based approach for frameworks that haven't migrated to two-phase yet. + /// + private async Task ConvertCodeWithRewritersAsync( + Document document, + CompilationUnitSyntax compilationUnit, + SemanticModel semanticModel, + Compilation compilation, + SyntaxNode? root) + { + var context = new MigrationContext(); + + // Step 1: Collect interface-implementing methods BEFORE any syntax modifications + // while the semantic model is still valid for the original syntax tree + var interfaceImplementingMethods = TryCollectInterfaceMethods( + compilationUnit, semanticModel, context); + + // Step 2: Convert assertions FIRST (while semantic model still matches the syntax tree) + compilationUnit = TryApplyRewriter( + compilationUnit, + () => CreateAssertionRewriter(semanticModel, compilation), + context, + "AssertionConversion"); + + // Step 3: Framework-specific conversions (also use semantic model while it still matches) + compilationUnit = TryApplyFrameworkSpecific( + compilationUnit, semanticModel, compilation, context); + + // Step 4: Fix method signatures that now contain await but aren't marked async + // Pass the collected interface methods to avoid converting interface implementations + compilationUnit = TryApplyRewriter( + compilationUnit, + () => new AsyncMethodSignatureRewriter(interfaceImplementingMethods), + context, + "AsyncSignatureFix"); + + // Step 5: Remove unnecessary base classes and interfaces + compilationUnit = TryApplyRewriter( + compilationUnit, + () => CreateBaseTypeRewriter(semanticModel, compilation), + context, + "BaseTypeRemoval"); + + // Step 6: Update lifecycle methods + compilationUnit = TryApplyRewriter( + compilationUnit, + () => CreateLifecycleRewriter(compilation), + context, + "LifecycleConversion"); + + // Step 7: Convert attributes + compilationUnit = TryApplyRewriter( + compilationUnit, + () => CreateAttributeRewriter(compilation), + context, + "AttributeConversion"); + + // Step 8: Ensure [Test] attribute is present when data attributes exist (NUnit-specific) + if (ShouldEnsureTestAttribute()) + { + compilationUnit = TryApplyRewriter( + compilationUnit, + () => new TestAttributeEnsurer(), + context, + "TestAttributeEnsurer"); + } - // Convert attributes - var attributeRewriter = CreateAttributeRewriter(compilation); - compilationUnit = (CompilationUnitSyntax)attributeRewriter.Visit(compilationUnit); + // Step 9: Remove framework usings and add TUnit usings (do this LAST) + // These are pure syntax operations with minimal risk + compilationUnit = MigrationHelpers.RemoveFrameworkUsings(compilationUnit, FrameworkName); - // Ensure [Test] attribute is present when data attributes exist (NUnit-specific) - if (ShouldEnsureTestAttribute()) - { - var testAttributeEnsurer = new TestAttributeEnsurer(); - compilationUnit = (CompilationUnitSyntax)testAttributeEnsurer.Visit(compilationUnit); - } + if (ShouldAddTUnitUsings()) + { + compilationUnit = MigrationHelpers.AddTUnitUsings(compilationUnit); + } + else + { + // Even if not adding TUnit usings, always add System.Threading.Tasks if there's async code + compilationUnit = MigrationHelpers.AddSystemThreadingTasksUsing(compilationUnit); + } - // Remove framework usings and add TUnit usings (do this LAST) - compilationUnit = MigrationHelpers.RemoveFrameworkUsings(compilationUnit, FrameworkName); + // Step 10: Clean up trivia issues that can occur after transformations + compilationUnit = CleanupClassMemberLeadingTrivia(compilationUnit); + compilationUnit = CleanupEndOfFileTrivia(compilationUnit); - if (ShouldAddTUnitUsings()) - { - compilationUnit = MigrationHelpers.AddTUnitUsings(compilationUnit); - } - else - { - // Even if not adding TUnit usings, always add System.Threading.Tasks if there's async code - compilationUnit = MigrationHelpers.AddSystemThreadingTasksUsing(compilationUnit); - } + // Normalize line endings to match original document (fixes cross-platform issues) + compilationUnit = NormalizeLineEndings(compilationUnit, root); - // Clean up trivia issues that can occur after transformations - compilationUnit = CleanupClassMemberLeadingTrivia(compilationUnit); - compilationUnit = CleanupEndOfFileTrivia(compilationUnit); + // Add TODO comments for any failures so users know what needs manual attention + if (context.HasFailures) + { + compilationUnit = AddFailureTodoComments(compilationUnit, context); + } - // Normalize line endings to match original document (fixes cross-platform issues) - compilationUnit = NormalizeLineEndings(compilationUnit, root); + // Return the document with updated syntax root, preserving original formatting + return document.WithSyntaxRoot(compilationUnit); + } - // Return the document with updated syntax root, preserving original formatting - return document.WithSyntaxRoot(compilationUnit); + /// + /// Safely collects interface-implementing methods before any syntax modifications. + /// + private static HashSet TryCollectInterfaceMethods( + CompilationUnitSyntax root, + SemanticModel semanticModel, + MigrationContext context) + { + try + { + return AsyncMethodSignatureRewriter.CollectInterfaceImplementingMethods(root, semanticModel); } - catch + catch (Exception ex) { - // If any transformation fails, return the original document unchanged - return document; + context.RecordFailure("CollectInterfaceMethods", ex); + return new HashSet(); + } + } + + /// + /// Safely applies a syntax rewriter, recording any failures to the migration context. + /// + private static CompilationUnitSyntax TryApplyRewriter( + CompilationUnitSyntax root, + Func rewriterFactory, + MigrationContext context, + string stepName) + { + try + { + var rewriter = rewriterFactory(); + return (CompilationUnitSyntax)rewriter.Visit(root); + } + catch (Exception ex) + { + context.RecordFailure(stepName, ex); + return root; // Return unchanged, continue with other steps + } + } + + /// + /// Safely applies framework-specific conversions, recording any failures. + /// + private CompilationUnitSyntax TryApplyFrameworkSpecific( + CompilationUnitSyntax root, + SemanticModel semanticModel, + Compilation compilation, + MigrationContext context) + { + try + { + return ApplyFrameworkSpecificConversions(root, semanticModel, compilation); + } + catch (Exception ex) + { + context.RecordFailure("FrameworkSpecificConversions", ex); + return root; + } + } + + /// + /// Adds TODO comments at the top of the file summarizing migration failures. + /// This helps users identify what needs manual attention. + /// + private static CompilationUnitSyntax AddFailureTodoComments( + CompilationUnitSyntax root, + MigrationContext context) + { + // Group failures by step and create summary comments + var failureSummary = context.Failures + .GroupBy(f => f.Step) + .Select(g => $"// TODO: TUnit migration - {g.Key}: {g.Count()} item(s) could not be converted automatically") + .ToList(); + + if (failureSummary.Count == 0) + { + return root; + } + + // Add header comment + var commentTrivia = new List + { + SyntaxFactory.Comment("// ============================================================"), + SyntaxFactory.EndOfLine("\n"), + SyntaxFactory.Comment("// TUnit Migration: Some items require manual attention"), + SyntaxFactory.EndOfLine("\n") + }; + + // Add failure summary lines + foreach (var summary in failureSummary) + { + commentTrivia.Add(SyntaxFactory.Comment(summary)); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); } + + commentTrivia.Add(SyntaxFactory.Comment("// ============================================================")); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + + var existingTrivia = root.GetLeadingTrivia(); + return root.WithLeadingTrivia(SyntaxFactory.TriviaList(commentTrivia).AddRange(existingTrivia)); + } + + /// + /// Adds TODO comments for failures from the two-phase architecture's ConversionPlan. + /// + private static CompilationUnitSyntax AddTwoPhaseFailureTodoComments( + CompilationUnitSyntax root, + ConversionPlan plan) + { + var failureSummary = plan.Failures + .GroupBy(f => f.Phase) + .Select(g => $"// TODO: TUnit migration - {g.Key}: {g.Count()} item(s) could not be converted automatically") + .ToList(); + + if (failureSummary.Count == 0) + { + return root; + } + + var commentTrivia = new List + { + SyntaxFactory.Comment("// ============================================================"), + SyntaxFactory.EndOfLine("\n"), + SyntaxFactory.Comment("// TUnit Migration: Some items require manual attention"), + SyntaxFactory.EndOfLine("\n") + }; + + foreach (var summary in failureSummary) + { + commentTrivia.Add(SyntaxFactory.Comment(summary)); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + } + + commentTrivia.Add(SyntaxFactory.Comment("// ============================================================")); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + + var existingTrivia = root.GetLeadingTrivia(); + return root.WithLeadingTrivia(SyntaxFactory.TriviaList(commentTrivia).AddRange(existingTrivia)); } protected abstract AttributeRewriter CreateAttributeRewriter(Compilation compilation); diff --git a/TUnit.Analyzers.CodeFixers/Base/MigrationContext.cs b/TUnit.Analyzers.CodeFixers/Base/MigrationContext.cs new file mode 100644 index 0000000000..390912cc69 --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/Base/MigrationContext.cs @@ -0,0 +1,60 @@ +using Microsoft.CodeAnalysis.Text; + +namespace TUnit.Analyzers.CodeFixers.Base; + +/// +/// Tracks migration progress and failures during code fix execution. +/// Enables partial success by recording what failed without aborting the entire migration. +/// Failures are surfaced via TODO comments in the generated code. +/// +public class MigrationContext +{ + public List Failures { get; } = new(); + + /// + /// Records a failure during a migration step. + /// The failure will be reported as a TODO comment in the migrated file. + /// + public void RecordFailure(string step, Exception ex, TextSpan? span = null) + { + var failure = new MigrationFailure( + Step: step, + Description: ex.Message, + Span: span, + OriginalCode: null, + StackTrace: ex.StackTrace); + + Failures.Add(failure); + } + + /// + /// Records a failure during item-level conversion (e.g., a specific assertion). + /// Includes the original code snippet for context in the TODO comment. + /// + public void RecordItemFailure(string step, string originalCode, Exception ex, TextSpan? span = null) + { + var failure = new MigrationFailure( + Step: step, + Description: ex.Message, + Span: span, + OriginalCode: originalCode, + StackTrace: ex.StackTrace); + + Failures.Add(failure); + } + + /// + /// Returns true if any failures were recorded. + /// + public bool HasFailures => Failures.Count > 0; +} + +/// +/// Represents a single migration failure with context for debugging. +/// +public record MigrationFailure( + string Step, + string Description, + TextSpan? Span = null, + string? OriginalCode = null, + string? StackTrace = null); diff --git a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs new file mode 100644 index 0000000000..ecef844479 --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/ConversionPlan.cs @@ -0,0 +1,504 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace TUnit.Analyzers.CodeFixers.Base.TwoPhase; + +/// +/// Represents the complete conversion plan for a single file. +/// Built during Phase 1 (Analysis) while the semantic model is valid. +/// Applied during Phase 2 (Transformation) using pure syntax operations. +/// +public class ConversionPlan +{ + /// + /// Assertions to convert (e.g., Assert.Equal → Assert.That().IsEqualTo()) + /// + public List Assertions { get; } = new(); + + /// + /// Attributes to convert (e.g., [Fact] → [Test]) + /// + public List Attributes { get; } = new(); + + /// + /// Attributes to remove entirely (e.g., [TestClass]) + /// + public List AttributeRemovals { get; } = new(); + + /// + /// Base types to remove (e.g., IClassFixture<T>) + /// + public List BaseTypeRemovals { get; } = new(); + + /// + /// Base types to add (e.g., IAsyncInitializer from IAsyncLifetime) + /// + public List BaseTypeAdditions { get; } = new(); + + /// + /// Class attributes to add (e.g., ClassDataSource from IClassFixture) + /// + public List ClassAttributeAdditions { get; } = new(); + + /// + /// Method attributes to add (e.g., [Before(Test)] on InitializeAsync) + /// + public List MethodAttributeAdditions { get; } = new(); + + /// + /// Methods that need async/Task added to their signature + /// + public List MethodSignatureChanges { get; } = new(); + + /// + /// Members to remove entirely (e.g., ITestOutputHelper fields) + /// + public List MemberRemovals { get; } = new(); + + /// + /// Constructor parameters to remove (e.g., ITestOutputHelper) + /// + public List ConstructorParameterRemovals { get; } = new(); + + /// + /// Record.Exception conversions to try-catch blocks + /// + public List RecordExceptionConversions { get; } = new(); + + /// + /// Invocations to replace (e.g., _testOutputHelper.WriteLine → Console.WriteLine) + /// + public List InvocationReplacements { get; } = new(); + + /// + /// TheoryData conversions (TheoryData<T> → IEnumerable<T>) + /// + public List TheoryDataConversions { get; } = new(); + + /// + /// Parameter attributes to convert (e.g., [Range(1, 5)] → [MatrixRange<int>(1, 5)]) + /// + public List ParameterAttributes { get; } = new(); + + /// + /// Usings to add + /// + public List UsingsToAdd { get; } = new(); + + /// + /// Using prefixes to remove (e.g., "Xunit") + /// + public List UsingPrefixesToRemove { get; } = new(); + + /// + /// Failures encountered during analysis + /// + public List Failures { get; } = new(); + + /// + /// Returns true if the plan has any conversions to apply + /// + public bool HasConversions => + Assertions.Count > 0 || + Attributes.Count > 0 || + AttributeRemovals.Count > 0 || + BaseTypeRemovals.Count > 0 || + BaseTypeAdditions.Count > 0 || + ClassAttributeAdditions.Count > 0 || + MethodAttributeAdditions.Count > 0 || + MethodSignatureChanges.Count > 0 || + MemberRemovals.Count > 0 || + ConstructorParameterRemovals.Count > 0 || + RecordExceptionConversions.Count > 0 || + InvocationReplacements.Count > 0 || + TheoryDataConversions.Count > 0 || + ParameterAttributes.Count > 0 || + UsingsToAdd.Count > 0 || + UsingPrefixesToRemove.Count > 0; + + /// + /// Returns true if any failures were recorded during analysis + /// + public bool HasFailures => Failures.Count > 0; +} + +/// +/// Base class for all conversion targets. Uses SyntaxAnnotation to track nodes +/// across syntax tree modifications. +/// +public abstract class ConversionTarget +{ + /// + /// Unique annotation to find this node after tree modifications + /// + public SyntaxAnnotation Annotation { get; } = new SyntaxAnnotation("TUnitMigration", Guid.NewGuid().ToString()); + + /// + /// Original source text for debugging/error messages + /// + public string OriginalText { get; init; } = ""; +} + +/// +/// Represents an assertion to convert +/// +public class AssertionConversion : ConversionTarget +{ + /// + /// The type of assertion conversion to perform + /// + public required AssertionConversionKind Kind { get; init; } + + /// + /// The new assertion code to generate (fully formed) + /// + public required string ReplacementCode { get; init; } + + /// + /// Whether this assertion introduces an await expression + /// + public bool IntroducesAwait { get; init; } + + /// + /// Optional TODO comment to add before the assertion + /// + public string? TodoComment { get; init; } +} + +/// +/// Types of assertion conversions +/// +public enum AssertionConversionKind +{ + // Equality (xUnit naming) + Equal, + NotEqual, + Same, + NotSame, + StrictEqual, + + // Equality (MSTest naming) + AreEqual, + AreNotEqual, + AreSame, + AreNotSame, + + // Boolean (xUnit naming) + True, + False, + + // Boolean (MSTest naming) + IsTrue, + IsFalse, + + // Null (xUnit naming) + Null, + NotNull, + + // Null (MSTest naming) + IsNull, + IsNotNull, + + // Collections (xUnit) + Empty, + NotEmpty, + Single, + Contains, + DoesNotContain, + All, + + // Collections (MSTest CollectionAssert) + CollectionAreEqual, + CollectionAreNotEqual, + CollectionAreEquivalent, + CollectionAreNotEquivalent, + CollectionContains, + CollectionDoesNotContain, + CollectionIsSubsetOf, + CollectionIsNotSubsetOf, + CollectionAllItemsAreUnique, + CollectionAllItemsAreNotNull, + CollectionAllItemsAreInstancesOfType, + + // Exceptions + Throws, + ThrowsAsync, + ThrowsAny, + ThrowsAnyAsync, + ThrowsException, + + // Type checks + IsType, + IsNotType, + IsAssignableFrom, + IsInstanceOfType, + IsNotInstanceOfType, + + // String (xUnit) + StartsWith, + EndsWith, + Matches, + + // String (MSTest StringAssert) + StringContains, + StringStartsWith, + StringEndsWith, + StringMatches, + StringDoesNotMatch, + + // Comparison + InRange, + NotInRange, + + // Other + Fail, + Skip, + Inconclusive, + Collection, + PropertyChanged, + + // Fallback + Unknown +} + +/// +/// Represents an attribute to convert +/// +public class AttributeConversion : ConversionTarget +{ + /// + /// The new attribute name + /// + public required string NewAttributeName { get; init; } + + /// + /// The new argument list (null to keep original, empty string to remove) + /// + public string? NewArgumentList { get; init; } + + /// + /// Additional attributes to add alongside this conversion (e.g., [Skip] from [Fact(Skip = "reason")]) + /// + public List? AdditionalAttributes { get; init; } +} + +/// +/// Represents an additional attribute to add during conversion +/// +public class AdditionalAttribute +{ + /// + /// The attribute name (e.g., "Skip") + /// + public required string Name { get; init; } + + /// + /// The attribute arguments (e.g., "(\"reason\")") + /// + public string? Arguments { get; init; } +} + +/// +/// Represents an attribute to remove entirely +/// +public class AttributeRemoval : ConversionTarget +{ +} + +/// +/// Represents a base type to remove +/// +public class BaseTypeRemoval : ConversionTarget +{ + /// + /// The base type name being removed (for logging) + /// + public required string TypeName { get; init; } +} + +/// +/// Represents a method signature that needs async/Task added +/// +public class MethodSignatureChange : ConversionTarget +{ + /// + /// Whether to add async modifier + /// + public bool AddAsync { get; init; } + + /// + /// Whether to change return type to Task + /// + public bool ChangeReturnTypeToTask { get; init; } + + /// + /// Whether to change return type to ValueTask + /// + public bool ChangeReturnTypeToValueTask { get; init; } + + /// + /// Whether to change return type from ValueTask to Task + /// + public bool ChangeValueTaskToTask { get; init; } + + /// + /// Whether to make the method public (for lifecycle methods) + /// + public bool MakePublic { get; init; } +} + +/// +/// Represents a member (field, property, method) to remove +/// +public class MemberRemoval : ConversionTarget +{ + /// + /// The member name being removed (for logging) + /// + public required string MemberName { get; init; } +} + +/// +/// Represents a constructor parameter to remove +/// +public class ConstructorParameterRemoval : ConversionTarget +{ + /// + /// The parameter name being removed + /// + public required string ParameterName { get; init; } + + /// + /// The parameter type being removed + /// + public required string ParameterType { get; init; } +} + +/// +/// Represents a failure during analysis +/// +public class ConversionFailure +{ + /// + /// The phase where the failure occurred + /// + public required string Phase { get; init; } + + /// + /// Description of what failed + /// + public required string Description { get; init; } + + /// + /// The original code that couldn't be converted + /// + public string? OriginalCode { get; init; } + + /// + /// The exception that caused the failure (if any) + /// + public Exception? Exception { get; init; } +} + +/// +/// Represents a base type (interface) to add to a class +/// +public class BaseTypeAddition : ConversionTarget +{ + /// + /// The interface name to add (e.g., "IAsyncInitializer") + /// + public required string TypeName { get; init; } +} + +/// +/// Represents an attribute to add to a class +/// +public class ClassAttributeAddition : ConversionTarget +{ + /// + /// The attribute code to add (e.g., "ClassDataSource(Shared = SharedType.PerClass)") + /// + public required string AttributeCode { get; init; } +} + +/// +/// Represents an attribute to add to a method +/// +public class MethodAttributeAddition : ConversionTarget +{ + /// + /// The attribute code to add (e.g., "Before(Test)") + /// + public required string AttributeCode { get; init; } + + /// + /// Whether to change the return type from Task to Task (keep as is) or apply other changes + /// + public string? NewReturnType { get; init; } +} + +/// +/// Represents a Record.Exception call that needs to be converted to try-catch +/// +public class RecordExceptionConversion : ConversionTarget +{ + /// + /// The variable name to assign the exception to (e.g., "ex") + /// + public required string VariableName { get; init; } + + /// + /// The body of the lambda to execute in the try block + /// + public required string TryBlockBody { get; init; } +} + +/// +/// Represents an invocation to replace (e.g., _testOutputHelper.WriteLine → Console.WriteLine) +/// +public class InvocationReplacement : ConversionTarget +{ + /// + /// The new invocation code (e.g., "Console.WriteLine(args)") + /// + public required string ReplacementCode { get; init; } +} + +/// +/// Represents a TheoryData field/property that needs to be converted to IEnumerable. +/// This handles both the type declaration and the object creation expression. +/// +public class TheoryDataConversion : ConversionTarget +{ + /// + /// The element type(s) from TheoryData<T> (e.g., "TimeSpan" from TheoryData<TimeSpan>) + /// + public required string ElementType { get; init; } + + /// + /// Annotation for the GenericName (TheoryData<T>) type syntax to convert to IEnumerable<T> + /// + public SyntaxAnnotation? TypeAnnotation { get; init; } + + /// + /// Annotation for the object creation expression to convert to array creation + /// + public SyntaxAnnotation? CreationAnnotation { get; init; } +} + +/// +/// Represents a parameter attribute to convert (e.g., [Range(1, 5)] → [MatrixRange<int>(1, 5)]) +/// +public class ParameterAttributeConversion : ConversionTarget +{ + /// + /// The new attribute name (e.g., "MatrixRange<int>") + /// + public required string NewAttributeName { get; init; } + + /// + /// The new argument list (null to keep original, empty string to remove) + /// + public string? NewArgumentList { get; init; } +} diff --git a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs new file mode 100644 index 0000000000..08e323bbdf --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationAnalyzer.cs @@ -0,0 +1,561 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; + +namespace TUnit.Analyzers.CodeFixers.Base.TwoPhase; + +/// +/// Phase 1: Analyzes source code while semantic model is valid. +/// Collects all conversion targets into a ConversionPlan. +/// Derived classes implement framework-specific analysis. +/// +/// CRITICAL: All semantic model queries MUST happen on the original root. +/// Annotations are added AFTER all semantic analysis is complete. +/// +public abstract class MigrationAnalyzer +{ + protected SemanticModel SemanticModel { get; } + protected Compilation Compilation { get; } + protected ConversionPlan Plan { get; } + + /// + /// The original, unmodified compilation unit for semantic model queries. + /// + private CompilationUnitSyntax _originalRoot = null!; + + /// + /// Methods that implement interface members (collected during pre-analysis). + /// These should NOT have their signatures changed to async. + /// + private HashSet _interfaceImplementingMethods = new(); + + protected MigrationAnalyzer(SemanticModel semanticModel, Compilation compilation) + { + SemanticModel = semanticModel; + Compilation = compilation; + Plan = new ConversionPlan(); + } + + /// + /// Analyzes the compilation unit and returns a conversion plan. + /// The returned syntax tree has annotations added to mark conversion targets. + /// + public (CompilationUnitSyntax AnnotatedRoot, ConversionPlan Plan) Analyze(CompilationUnitSyntax root) + { + _originalRoot = root; + + // === PRE-ANALYSIS PHASE === + // Collect all semantic information BEFORE any tree modifications. + // The semantic model only works on the original tree. + + // Collect interface-implementing methods (needed for async conversion decision) + CollectInterfaceImplementingMethods(); + + // === ANALYSIS PHASE === + // Now analyze and annotate. Use span-based lookup when semantic info is needed. + var annotatedRoot = root; + + // 1. Analyze assertions (uses semantic model on original nodes) + annotatedRoot = AnalyzeAssertions(annotatedRoot); + + // 2. Analyze attributes + annotatedRoot = AnalyzeAttributes(annotatedRoot); + + // 2b. Analyze parameter attributes (e.g., [Range]) + annotatedRoot = AnalyzeParameterAttributes(annotatedRoot); + + // 2c. Analyze methods for missing attributes (e.g., add [Test] when only [TestCase]) + annotatedRoot = AnalyzeMethodsForMissingAttributes(annotatedRoot); + + // 3. Analyze base types + annotatedRoot = AnalyzeBaseTypes(annotatedRoot); + + // 4. Analyze members (fields, properties for removal) + annotatedRoot = AnalyzeMembers(annotatedRoot); + + // 5. Analyze constructor parameters + annotatedRoot = AnalyzeConstructorParameters(annotatedRoot); + + // 6. Analyze special invocations (e.g., Record.Exception, ITestOutputHelper) + annotatedRoot = AnalyzeSpecialInvocations(annotatedRoot); + + // 7. Analyze TheoryData fields/properties + annotatedRoot = AnalyzeTheoryData(annotatedRoot); + + // 8. Analyze method signatures (uses pre-collected interface info) + annotatedRoot = AnalyzeMethodSignatures(annotatedRoot); + + // 9. Determine usings to add/remove + AnalyzeUsings(); + + return (annotatedRoot, Plan); + } + + /// + /// Pre-analysis: Collects all interface-implementing methods from the original tree. + /// This must be done before any tree modifications. + /// + private void CollectInterfaceImplementingMethods() + { + foreach (var method in _originalRoot.DescendantNodes().OfType()) + { + try + { + var methodSymbol = SemanticModel.GetDeclaredSymbol(method); + if (methodSymbol == null) continue; + + var containingType = methodSymbol.ContainingType; + if (containingType == null) continue; + + foreach (var iface in containingType.AllInterfaces) + { + foreach (var member in iface.GetMembers().OfType()) + { + var impl = containingType.FindImplementationForInterfaceMember(member); + if (SymbolEqualityComparer.Default.Equals(impl, methodSymbol)) + { + _interfaceImplementingMethods.Add(method.Span); + break; + } + } + } + } + catch + { + // Ignore errors in pre-analysis + } + } + } + + /// + /// Checks if a method at the given span implements an interface method. + /// Uses pre-collected data to avoid semantic model queries on modified tree. + /// + protected bool IsInterfaceImplementation(TextSpan methodSpan) + { + return _interfaceImplementingMethods.Contains(methodSpan); + } + + /// + /// Analyzes assertions and adds them to the plan. + /// Returns the root with annotations added to assertion nodes. + /// + protected virtual CompilationUnitSyntax AnalyzeAssertions(CompilationUnitSyntax root) + { + // Find assertions on the ORIGINAL tree (for semantic analysis) + var assertionNodes = FindAssertionNodes(_originalRoot).ToList(); + var currentRoot = root; + + foreach (var originalNode in assertionNodes) + { + try + { + var conversion = AnalyzeAssertion(originalNode); + if (conversion != null) + { + // Check if the containing method has ref/out parameters + // If so, we can't make it async, so use .Wait() instead of await + if (conversion.IntroducesAwait && ContainingMethodHasRefOrOutParameters(originalNode)) + { + // Create a new conversion with .Wait() instead of await + var newReplacementCode = conversion.ReplacementCode; + if (newReplacementCode.StartsWith("await ")) + { + newReplacementCode = newReplacementCode.Substring(6) + ".Wait()"; + } + conversion = new AssertionConversion + { + Kind = conversion.Kind, + OriginalText = conversion.OriginalText, + ReplacementCode = newReplacementCode, + IntroducesAwait = false, + TodoComment = conversion.TodoComment + }; + } + + Plan.Assertions.Add(conversion); + + // Find the corresponding node in the current tree by span + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalNode.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(conversion.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "AssertionAnalysis", + Description = ex.Message, + OriginalCode = originalNode.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + /// + /// Checks if the method containing this node has ref or out parameters. + /// Methods with ref/out parameters cannot be made async, so assertions must use .Wait() instead of await. + /// + private static bool ContainingMethodHasRefOrOutParameters(SyntaxNode node) + { + var containingMethod = node.Ancestors() + .OfType() + .FirstOrDefault(); + + if (containingMethod == null) + return false; + + return containingMethod.ParameterList.Parameters + .Any(p => p.Modifiers.Any(m => m.IsKind(SyntaxKind.RefKeyword) || m.IsKind(SyntaxKind.OutKeyword))); + } + + /// + /// Finds all assertion invocation nodes in the syntax tree. + /// Called on the ORIGINAL tree for semantic model compatibility. + /// + protected abstract IEnumerable FindAssertionNodes(CompilationUnitSyntax root); + + /// + /// Analyzes a single assertion and returns the conversion info. + /// Returns null if this node should not be converted. + /// Called with nodes from the ORIGINAL tree. + /// + protected abstract AssertionConversion? AnalyzeAssertion(InvocationExpressionSyntax node); + + /// + /// Analyzes attributes and adds them to the plan. + /// + protected virtual CompilationUnitSyntax AnalyzeAttributes(CompilationUnitSyntax root) + { + var attributeNodes = _originalRoot.DescendantNodes().OfType().ToList(); + var currentRoot = root; + + foreach (var originalNode in attributeNodes) + { + try + { + // Check for removal first + if (ShouldRemoveAttribute(originalNode)) + { + var removal = new AttributeRemoval { OriginalText = originalNode.ToString() }; + Plan.AttributeRemovals.Add(removal); + + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalNode.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(removal.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + continue; + } + + // Check for conversion + var conversion = AnalyzeAttribute(originalNode); + if (conversion != null) + { + Plan.Attributes.Add(conversion); + + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalNode.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(conversion.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "AttributeAnalysis", + Description = ex.Message, + OriginalCode = originalNode.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + /// + /// Returns true if this attribute should be removed entirely. + /// + protected abstract bool ShouldRemoveAttribute(AttributeSyntax node); + + /// + /// Analyzes a single attribute and returns the conversion info. + /// Returns null if this attribute should not be converted. + /// + protected abstract AttributeConversion? AnalyzeAttribute(AttributeSyntax node); + + /// + /// Analyzes parameter attributes (e.g., [Range] on method parameters). + /// + protected virtual CompilationUnitSyntax AnalyzeParameterAttributes(CompilationUnitSyntax root) + { + var parameterNodes = _originalRoot.DescendantNodes().OfType().ToList(); + var currentRoot = root; + + foreach (var parameter in parameterNodes) + { + if (parameter.AttributeLists.Count == 0) continue; + + foreach (var attributeList in parameter.AttributeLists) + { + foreach (var originalAttr in attributeList.Attributes) + { + try + { + var conversion = AnalyzeParameterAttribute(originalAttr, parameter); + if (conversion != null) + { + Plan.ParameterAttributes.Add(conversion); + + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalAttr.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(conversion.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "ParameterAttributeAnalysis", + Description = ex.Message, + OriginalCode = originalAttr.ToString(), + Exception = ex + }); + } + } + } + } + + return currentRoot; + } + + /// + /// Analyzes a single parameter attribute and returns the conversion info. + /// Returns null if this attribute should not be converted. + /// + protected virtual ParameterAttributeConversion? AnalyzeParameterAttribute(AttributeSyntax attr, ParameterSyntax parameter) + { + return null; // Default: no parameter attribute conversion + } + + /// + /// Analyzes methods to add missing attributes (e.g., add [Test] when only [TestCase]). + /// + protected virtual CompilationUnitSyntax AnalyzeMethodsForMissingAttributes(CompilationUnitSyntax root) + { + return root; // Default: no missing attribute additions + } + + /// + /// Analyzes base types and adds removals to the plan. + /// + protected virtual CompilationUnitSyntax AnalyzeBaseTypes(CompilationUnitSyntax root) + { + var classNodes = _originalRoot.DescendantNodes().OfType().ToList(); + var currentRoot = root; + + foreach (var classNode in classNodes) + { + if (classNode.BaseList == null) continue; + + foreach (var originalBaseType in classNode.BaseList.Types) + { + try + { + if (ShouldRemoveBaseType(originalBaseType)) + { + var removal = new BaseTypeRemoval + { + TypeName = originalBaseType.Type.ToString(), + OriginalText = originalBaseType.ToString() + }; + Plan.BaseTypeRemovals.Add(removal); + + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalBaseType.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(removal.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "BaseTypeAnalysis", + Description = ex.Message, + OriginalCode = originalBaseType.ToString(), + Exception = ex + }); + } + } + } + + return currentRoot; + } + + /// + /// Returns true if this base type should be removed. + /// + protected abstract bool ShouldRemoveBaseType(BaseTypeSyntax baseType); + + /// + /// Analyzes members (fields, properties) for removal. + /// + protected virtual CompilationUnitSyntax AnalyzeMembers(CompilationUnitSyntax root) + { + // Default: no member removals. Override in derived classes. + return root; + } + + /// + /// Analyzes constructor parameters for removal. + /// + protected virtual CompilationUnitSyntax AnalyzeConstructorParameters(CompilationUnitSyntax root) + { + // Default: no parameter removals. Override in derived classes. + return root; + } + + /// + /// Analyzes special invocations (e.g., Record.Exception, ITestOutputHelper.WriteLine). + /// + protected virtual CompilationUnitSyntax AnalyzeSpecialInvocations(CompilationUnitSyntax root) + { + // Default: no special invocations. Override in derived classes. + return root; + } + + /// + /// Analyzes TheoryData fields/properties for conversion to IEnumerable. + /// + protected virtual CompilationUnitSyntax AnalyzeTheoryData(CompilationUnitSyntax root) + { + // Default: no TheoryData conversions. Override in derived classes. + return root; + } + + /// + /// Analyzes method signatures for async conversion. + /// This should be called after assertion analysis to know which methods will have awaits. + /// + protected virtual CompilationUnitSyntax AnalyzeMethodSignatures(CompilationUnitSyntax root) + { + // Find methods that will contain await after transformation + // We need to map from annotated nodes back to original spans + var methodSpansWithAwaits = new HashSet(); + + // Check which methods contain assertions that introduce await + foreach (var assertion in Plan.Assertions.Where(a => a.IntroducesAwait)) + { + // Find the annotated node in the current tree + var assertionNode = root.DescendantNodes() + .FirstOrDefault(n => n.HasAnnotation(assertion.Annotation)); + + if (assertionNode != null) + { + // Find the containing method + var containingMethod = assertionNode.Ancestors() + .OfType() + .FirstOrDefault(); + + if (containingMethod != null) + { + // Use the original span (before annotation was added) + // Since annotations don't change spans, this works + methodSpansWithAwaits.Add(containingMethod.Span); + } + } + } + + var currentRoot = root; + + // Process each method that needs async + foreach (var methodSpan in methodSpansWithAwaits) + { + var method = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Span == methodSpan); + + if (method == null) continue; + + // Skip if already async + if (method.Modifiers.Any(SyntaxKind.AsyncKeyword)) + continue; + + // Skip interface implementations (uses pre-collected data) + if (IsInterfaceImplementation(methodSpan)) + continue; + + try + { + var change = new MethodSignatureChange + { + AddAsync = true, + ChangeReturnTypeToTask = method.ReturnType.ToString() == "void", + OriginalText = $"{method.ReturnType} {method.Identifier}" + }; + + Plan.MethodSignatureChanges.Add(change); + + var annotatedMethod = method.WithAdditionalAnnotations(change.Annotation); + currentRoot = currentRoot.ReplaceNode(method, annotatedMethod); + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "MethodSignatureAnalysis", + Description = ex.Message, + OriginalCode = method.Identifier.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + /// + /// Determines which usings to add and remove. + /// + protected abstract void AnalyzeUsings(); + + /// + /// Gets the original root for semantic model queries. + /// Use this when you need to query the semantic model. + /// + protected CompilationUnitSyntax OriginalRoot => _originalRoot; +} diff --git a/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs new file mode 100644 index 0000000000..ed1f7a42ed --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/Base/TwoPhase/MigrationTransformer.cs @@ -0,0 +1,1157 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Analyzers.Migrators.Base; + +namespace TUnit.Analyzers.CodeFixers.Base.TwoPhase; + +/// +/// Phase 2: Applies transformations based on the ConversionPlan. +/// Uses annotations to find nodes that need conversion. +/// No semantic model needed - pure syntax transformations. +/// +public class MigrationTransformer +{ + private readonly ConversionPlan _plan; + private readonly string _frameworkName; + + public MigrationTransformer(ConversionPlan plan, string frameworkName) + { + _plan = plan; + _frameworkName = frameworkName; + } + + /// + /// Applies all transformations from the conversion plan. + /// + public CompilationUnitSyntax Transform(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Apply transformations in dependency order: + // 1. Record.Exception conversions (before assertions - may affect structure) + currentRoot = TransformRecordExceptionCalls(currentRoot); + + // 2. Invocation replacements (ITestOutputHelper → Console) + currentRoot = TransformInvocationReplacements(currentRoot); + + // 3. TheoryData conversions (TheoryData → IEnumerable) + currentRoot = TransformTheoryData(currentRoot); + + // 4. Assertions (may introduce await) + currentRoot = TransformAssertions(currentRoot); + + // 4. Method signatures (add async/Task based on new awaits) + currentRoot = TransformMethodSignatures(currentRoot); + + // 5. Add method attributes (e.g., [Before(Test)]) + currentRoot = AddMethodAttributes(currentRoot); + + // 6. Attributes + currentRoot = TransformAttributes(currentRoot); + + // 6b. Parameter attributes (e.g., [Range] → [MatrixRange]) + currentRoot = TransformParameterAttributes(currentRoot); + + // 7. Remove attributes + currentRoot = RemoveAttributes(currentRoot); + + // 8. Remove base types + currentRoot = RemoveBaseTypes(currentRoot); + + // 9. Add base types (e.g., IAsyncInitializer) + currentRoot = AddBaseTypes(currentRoot); + + // 10. Add class attributes (e.g., ClassDataSource) + currentRoot = AddClassAttributes(currentRoot); + + // 11. Remove members + currentRoot = RemoveMembers(currentRoot); + + // 12. Remove constructor parameters + currentRoot = RemoveConstructorParameters(currentRoot); + + // 13. Update usings (last, pure syntax) + currentRoot = TransformUsings(currentRoot); + + // 14. Add TODO comments for failures + if (_plan.HasFailures) + { + currentRoot = AddFailureComments(currentRoot); + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformRecordExceptionCalls(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var conversion in _plan.RecordExceptionConversions) + { + try + { + // Find the annotated local declaration statement + var statement = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(s => s.HasAnnotation(conversion.Annotation)); + + if (statement == null) continue; + + // Get indentation from the statement + var leadingTrivia = statement.GetLeadingTrivia(); + var indentation = leadingTrivia + .Where(t => t.IsKind(SyntaxKind.WhitespaceTrivia)) + .LastOrDefault() + .ToString(); + + // Build the try body as a statement + var tryBlockBody = conversion.TryBlockBody.Trim(); + // Ensure it ends with semicolon if not already + if (!tryBlockBody.EndsWith(";") && !tryBlockBody.EndsWith("}")) + { + tryBlockBody += ";"; + } + + var tryBodyStatement = SyntaxFactory.ParseStatement(tryBlockBody); + var tryBlock = SyntaxFactory.Block(tryBodyStatement); + + // Build the catch block + var catchAssignment = SyntaxFactory.ParseStatement($"{conversion.VariableName} = e;"); + var catchBlock = SyntaxFactory.Block(catchAssignment); + + var catchClause = SyntaxFactory.CatchClause() + .WithDeclaration( + SyntaxFactory.CatchDeclaration( + SyntaxFactory.IdentifierName("Exception"), + SyntaxFactory.Identifier("e"))) + .WithBlock(catchBlock); + + var tryCatchStatement = SyntaxFactory.TryStatement() + .WithBlock(tryBlock) + .WithCatches(SyntaxFactory.SingletonList(catchClause)); + + // Build the variable declaration with proper trailing newline + var variableDecl = SyntaxFactory.ParseStatement($"Exception? {conversion.VariableName} = null;"); + + // Create a list of statements to replace the original + // Add newline after variable declaration and proper indentation for try statement + var newStatements = new List + { + variableDecl + .WithLeadingTrivia(leadingTrivia) + .WithTrailingTrivia(SyntaxFactory.EndOfLine("\n")), + tryCatchStatement + .WithLeadingTrivia(SyntaxFactory.Whitespace(indentation)) + .WithTrailingTrivia(statement.GetTrailingTrivia()) + }; + + // Find the containing block and replace the statement with the new statements + var containingBlock = statement.Ancestors().OfType().FirstOrDefault(); + if (containingBlock != null) + { + var statementIndex = containingBlock.Statements.IndexOf(statement); + if (statementIndex >= 0) + { + var newStmtList = containingBlock.Statements + .RemoveAt(statementIndex) + .InsertRange(statementIndex, newStatements); + var newBlock = containingBlock.WithStatements(newStmtList); + currentRoot = currentRoot.ReplaceNode(containingBlock, newBlock); + } + } + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "RecordExceptionTransformation", + Description = ex.Message, + OriginalCode = conversion.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformInvocationReplacements(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var replacement in _plan.InvocationReplacements) + { + try + { + var invocation = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(i => i.HasAnnotation(replacement.Annotation)); + + if (invocation == null) continue; + + // Parse the replacement code + var newInvocation = SyntaxFactory.ParseExpression(replacement.ReplacementCode); + + currentRoot = currentRoot.ReplaceNode(invocation, newInvocation + .WithLeadingTrivia(invocation.GetLeadingTrivia()) + .WithTrailingTrivia(invocation.GetTrailingTrivia())); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "InvocationReplacementTransformation", + Description = ex.Message, + OriginalCode = replacement.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformTheoryData(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var conversion in _plan.TheoryDataConversions) + { + try + { + // First, transform the object creation expression to array creation + if (conversion.CreationAnnotation != null) + { + var objectCreation = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.HasAnnotation(conversion.CreationAnnotation)); + + if (objectCreation?.Initializer != null) + { + // Build array type: T[] + var arrayType = SyntaxFactory.ArrayType( + SyntaxFactory.ParseTypeName(conversion.ElementType), + SyntaxFactory.SingletonList( + SyntaxFactory.ArrayRankSpecifier( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.OmittedArraySizeExpression() + ) + ) + ) + ).WithoutTrailingTrivia(); + + // Get the open brace token and ensure it has proper newline trivia + var openBrace = objectCreation.Initializer.OpenBraceToken; + if (!openBrace.LeadingTrivia.Any(t => t.IsKind(SyntaxKind.EndOfLineTrivia))) + { + // Add newline and proper indentation before the brace + openBrace = openBrace.WithLeadingTrivia( + SyntaxFactory.EndOfLine("\n"), + SyntaxFactory.Whitespace(" ")); + } + + // Create array initializer from the collection initializer + var newInitializer = SyntaxFactory.InitializerExpression( + SyntaxKind.ArrayInitializerExpression, + openBrace, + objectCreation.Initializer.Expressions, + objectCreation.Initializer.CloseBraceToken); + + // Build the array creation expression + var newKeyword = SyntaxFactory.Token(SyntaxKind.NewKeyword) + .WithLeadingTrivia(objectCreation.GetLeadingTrivia()) + .WithTrailingTrivia(SyntaxFactory.Space); + + var arrayCreation = SyntaxFactory.ArrayCreationExpression( + newKeyword, + arrayType, + newInitializer + ).WithTrailingTrivia(objectCreation.GetTrailingTrivia()); + + currentRoot = currentRoot.ReplaceNode(objectCreation, arrayCreation); + } + } + + // Then, transform the type declaration from TheoryData to IEnumerable + if (conversion.TypeAnnotation != null) + { + var genericName = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.HasAnnotation(conversion.TypeAnnotation)); + + if (genericName != null) + { + var enumerableType = SyntaxFactory.GenericName( + SyntaxFactory.Identifier("IEnumerable"), + SyntaxFactory.TypeArgumentList( + SyntaxFactory.SeparatedList(genericName.TypeArgumentList.Arguments))) + .WithLeadingTrivia(genericName.GetLeadingTrivia()) + .WithTrailingTrivia(genericName.GetTrailingTrivia()); + + currentRoot = currentRoot.ReplaceNode(genericName, enumerableType); + } + } + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "TheoryDataTransformation", + Description = ex.Message, + OriginalCode = conversion.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformAssertions(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var assertion in _plan.Assertions) + { + try + { + var node = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.HasAnnotation(assertion.Annotation)); + + if (node == null) continue; + + // Parse the replacement code + var replacement = SyntaxFactory.ParseExpression(assertion.ReplacementCode); + + // Find the containing statement + var containingStatement = node.Ancestors() + .OfType() + .FirstOrDefault(); + + if (containingStatement != null) + { + // Build the leading trivia, including TODO comment if present + var leadingTrivia = containingStatement.GetLeadingTrivia(); + if (!string.IsNullOrEmpty(assertion.TodoComment)) + { + // Extract the indentation from existing trivia + var indentationTrivia = leadingTrivia + .Where(t => t.IsKind(SyntaxKind.WhitespaceTrivia)) + .LastOrDefault(); + + var todoTrivia = new List(); + if (indentationTrivia != default) + { + todoTrivia.Add(indentationTrivia); + } + todoTrivia.Add(SyntaxFactory.Comment(assertion.TodoComment)); + todoTrivia.Add(SyntaxFactory.EndOfLine("\n")); + + // Combine TODO comment with existing leading trivia + leadingTrivia = SyntaxFactory.TriviaList(todoTrivia.Concat(leadingTrivia)); + } + + // Replace the entire statement with the new expression statement + var newStatement = SyntaxFactory.ExpressionStatement(replacement) + .WithLeadingTrivia(leadingTrivia) + .WithTrailingTrivia(containingStatement.GetTrailingTrivia()); + + currentRoot = currentRoot.ReplaceNode(containingStatement, newStatement); + } + else + { + // Just replace the expression + currentRoot = currentRoot.ReplaceNode(node, replacement + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(node.GetTrailingTrivia())); + } + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "AssertionTransformation", + Description = ex.Message, + OriginalCode = assertion.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformMethodSignatures(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var change in _plan.MethodSignatureChanges) + { + try + { + var method = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.HasAnnotation(change.Annotation)); + + if (method == null) continue; + + var newMethod = method; + + // Add async modifier if needed + if (change.AddAsync && !method.Modifiers.Any(SyntaxKind.AsyncKeyword)) + { + var asyncToken = SyntaxFactory.Token(SyntaxKind.AsyncKeyword) + .WithTrailingTrivia(SyntaxFactory.Space); + + // Insert async before the return type, after other modifiers + var newModifiers = method.Modifiers.Add(asyncToken); + newMethod = newMethod.WithModifiers(newModifiers); + } + + // Change return type to Task if needed (from void) + if (change.ChangeReturnTypeToTask && method.ReturnType.ToString() == "void") + { + var taskType = SyntaxFactory.IdentifierName("Task") + .WithTrailingTrivia(SyntaxFactory.Space); + newMethod = newMethod.WithReturnType(taskType); + } + + // Change ValueTask to Task if needed (for IAsyncLifetime.InitializeAsync → IAsyncInitializer) + if (change.ChangeValueTaskToTask && method.ReturnType.ToString() == "ValueTask") + { + var taskType = SyntaxFactory.IdentifierName("Task") + .WithTrailingTrivia(SyntaxFactory.Space); + newMethod = newMethod.WithReturnType(taskType); + } + + // Make method public if needed (for lifecycle methods) + if (change.MakePublic) + { + var hasPublicModifier = newMethod.Modifiers.Any(SyntaxKind.PublicKeyword); + if (!hasPublicModifier) + { + // Remove existing access modifiers (private, protected, internal) + var newModifiers = new SyntaxTokenList(); + var publicToken = SyntaxFactory.Token(SyntaxKind.PublicKeyword) + .WithTrailingTrivia(SyntaxFactory.Space); + + // Add public at the start + newModifiers = newModifiers.Add(publicToken); + + // Keep non-access modifiers (static, async, etc.) + foreach (var modifier in newMethod.Modifiers) + { + if (!modifier.IsKind(SyntaxKind.PrivateKeyword) && + !modifier.IsKind(SyntaxKind.ProtectedKeyword) && + !modifier.IsKind(SyntaxKind.InternalKeyword) && + !modifier.IsKind(SyntaxKind.PublicKeyword)) + { + newModifiers = newModifiers.Add(modifier); + } + } + + newMethod = newMethod.WithModifiers(newModifiers); + } + } + + if (newMethod != method) + { + currentRoot = currentRoot.ReplaceNode(method, newMethod); + } + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "MethodSignatureTransformation", + Description = ex.Message, + OriginalCode = change.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformAttributes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var conversion in _plan.Attributes) + { + try + { + var attribute = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(a => a.HasAnnotation(conversion.Annotation)); + + if (attribute == null) continue; + + // Build new attribute - handle generic names like ClassDataSource + AttributeSyntax newAttribute; + if (conversion.NewAttributeName.Contains('<')) + { + // Generic attribute name - parse it properly + var fullAttrCode = conversion.NewArgumentList != null && conversion.NewArgumentList.Length > 0 + ? conversion.NewAttributeName + conversion.NewArgumentList + : conversion.NewArgumentList == null && attribute.ArgumentList != null + ? conversion.NewAttributeName + attribute.ArgumentList + : conversion.NewAttributeName; + + newAttribute = ParseAttributeCode(fullAttrCode); + } + else + { + // Use ParseName for qualified names (e.g., System.Obsolete) + var newName = conversion.NewAttributeName.Contains('.') + ? SyntaxFactory.ParseName(conversion.NewAttributeName) + : (NameSyntax)SyntaxFactory.IdentifierName(conversion.NewAttributeName); + newAttribute = SyntaxFactory.Attribute(newName); + + // Add argument list if specified + if (conversion.NewArgumentList != null && conversion.NewArgumentList.Length > 0) + { + var argList = SyntaxFactory.ParseAttributeArgumentList(conversion.NewArgumentList); + newAttribute = newAttribute.WithArgumentList(argList); + } + else if (conversion.NewArgumentList == null && attribute.ArgumentList != null) + { + // Keep original arguments + newAttribute = newAttribute.WithArgumentList(attribute.ArgumentList); + } + } + + // Handle additional attributes (e.g., Skip from Fact(Skip = "reason")) + if (conversion.AdditionalAttributes != null && conversion.AdditionalAttributes.Count > 0) + { + // Find the containing attribute list + var attributeList = attribute.Ancestors() + .OfType() + .FirstOrDefault(); + + if (attributeList != null) + { + // Create separate attribute lists - each additional attribute on its own line + var newAttributeLists = new List(); + + // Extract just the indentation from leading trivia (whitespace at the end) + var fullLeadingTrivia = attributeList.GetLeadingTrivia(); + var indentationTrivia = SyntaxFactory.TriviaList( + fullLeadingTrivia.Where(t => t.IsKind(SyntaxKind.WhitespaceTrivia))); + + // First, create the attribute list with the main converted attribute + // Keep the full leading trivia (including any blank lines before) for the first attribute + var mainAttrList = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList(newAttribute)) + .WithLeadingTrivia(fullLeadingTrivia) + .WithTrailingTrivia(SyntaxFactory.EndOfLine("\n")); + newAttributeLists.Add(mainAttrList); + + // Create separate attribute lists for each additional attribute + foreach (var additional in conversion.AdditionalAttributes) + { + var additionalAttr = SyntaxFactory.Attribute( + SyntaxFactory.IdentifierName(additional.Name)); + + if (!string.IsNullOrEmpty(additional.Arguments)) + { + additionalAttr = additionalAttr.WithArgumentList( + SyntaxFactory.ParseAttributeArgumentList(additional.Arguments)); + } + + // Use only indentation for additional attributes (no blank lines) + var additionalAttrList = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList(additionalAttr)) + .WithLeadingTrivia(indentationTrivia) + .WithTrailingTrivia(SyntaxFactory.EndOfLine("\n")); + newAttributeLists.Add(additionalAttrList); + } + + // Replace original attribute list with multiple new ones + currentRoot = currentRoot.ReplaceNode(attributeList, newAttributeLists); + continue; + } + } + + currentRoot = currentRoot.ReplaceNode(attribute, newAttribute + .WithLeadingTrivia(attribute.GetLeadingTrivia()) + .WithTrailingTrivia(attribute.GetTrailingTrivia())); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "AttributeTransformation", + Description = ex.Message, + OriginalCode = conversion.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformParameterAttributes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var conversion in _plan.ParameterAttributes) + { + try + { + var attribute = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(a => a.HasAnnotation(conversion.Annotation)); + + if (attribute == null) continue; + + // Build new attribute - handle generic names like MatrixRange + AttributeSyntax newAttribute; + if (conversion.NewAttributeName.Contains('<')) + { + // Generic attribute name - parse it properly + var fullAttrCode = conversion.NewArgumentList != null && conversion.NewArgumentList.Length > 0 + ? conversion.NewAttributeName + conversion.NewArgumentList + : conversion.NewArgumentList == null && attribute.ArgumentList != null + ? conversion.NewAttributeName + attribute.ArgumentList + : conversion.NewAttributeName; + + newAttribute = ParseAttributeCode(fullAttrCode); + } + else + { + var newName = (NameSyntax)SyntaxFactory.IdentifierName(conversion.NewAttributeName); + newAttribute = SyntaxFactory.Attribute(newName); + + // Add argument list if specified + if (conversion.NewArgumentList != null && conversion.NewArgumentList.Length > 0) + { + var argList = SyntaxFactory.ParseAttributeArgumentList(conversion.NewArgumentList); + newAttribute = newAttribute.WithArgumentList(argList); + } + else if (conversion.NewArgumentList == null && attribute.ArgumentList != null) + { + // Keep original arguments + newAttribute = newAttribute.WithArgumentList(attribute.ArgumentList); + } + } + + // Preserve trivia + newAttribute = newAttribute + .WithLeadingTrivia(attribute.GetLeadingTrivia()) + .WithTrailingTrivia(attribute.GetTrailingTrivia()); + + currentRoot = currentRoot.ReplaceNode(attribute, newAttribute); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "ParameterAttributeTransformation", + Description = ex.Message, + OriginalCode = conversion.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax RemoveAttributes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var removal in _plan.AttributeRemovals) + { + try + { + var attribute = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(a => a.HasAnnotation(removal.Annotation)); + + if (attribute == null) continue; + + // Find the attribute list containing this attribute + var attributeList = attribute.Ancestors() + .OfType() + .FirstOrDefault(); + + if (attributeList == null) continue; + + if (attributeList.Attributes.Count == 1) + { + // Remove the entire attribute list without keeping its trivia + // This prevents extra indentation from being left behind + currentRoot = currentRoot.RemoveNode(attributeList, SyntaxRemoveOptions.KeepNoTrivia)!; + } + else + { + // Remove just this attribute from the list + var newAttributes = attributeList.Attributes.Remove(attribute); + var newList = attributeList.WithAttributes(newAttributes); + currentRoot = currentRoot.ReplaceNode(attributeList, newList); + } + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "AttributeRemoval", + Description = ex.Message, + OriginalCode = removal.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax RemoveBaseTypes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var removal in _plan.BaseTypeRemovals) + { + try + { + var baseType = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(b => b.HasAnnotation(removal.Annotation)); + + if (baseType == null) continue; + + var baseList = baseType.Ancestors() + .OfType() + .FirstOrDefault(); + + if (baseList == null) continue; + + if (baseList.Types.Count == 1) + { + // Remove the entire base list + var classDecl = baseList.Ancestors() + .OfType() + .FirstOrDefault(); + + if (classDecl != null) + { + var newClass = classDecl.WithBaseList(null); + + // Remove trailing trivia from the element before the base list + // This could be ParameterList, TypeParameterList, or Identifier + if (classDecl.ParameterList != null) + { + // Primary constructor - remove trailing trivia from the close paren + var paramList = classDecl.ParameterList; + var closeParen = paramList.CloseParenToken.WithTrailingTrivia(SyntaxFactory.TriviaList()); + newClass = newClass.WithParameterList(paramList.WithCloseParenToken(closeParen)); + } + else if (classDecl.TypeParameterList != null) + { + // Generic class - remove trailing trivia from the close angle bracket + var typeParamList = classDecl.TypeParameterList; + var closeAngle = typeParamList.GreaterThanToken.WithTrailingTrivia(SyntaxFactory.TriviaList()); + newClass = newClass.WithTypeParameterList(typeParamList.WithGreaterThanToken(closeAngle)); + } + else + { + // Regular class - the identifier might have trailing trivia + newClass = newClass.WithIdentifier(classDecl.Identifier.WithTrailingTrivia(SyntaxFactory.TriviaList())); + } + + // Preserve the open brace trivia - it should have newline before it + if (classDecl.OpenBraceToken != default) + { + var openBrace = classDecl.OpenBraceToken; + if (!openBrace.LeadingTrivia.Any(t => t.IsKind(SyntaxKind.EndOfLineTrivia))) + { + // Add newline before the open brace + openBrace = openBrace.WithLeadingTrivia(SyntaxFactory.EndOfLine("\n")); + } + newClass = newClass.WithOpenBraceToken(openBrace); + } + + currentRoot = currentRoot.ReplaceNode(classDecl, newClass); + } + } + else + { + // Remove just this base type + var newTypes = baseList.Types.Remove(baseType); + var newList = baseList.WithTypes(newTypes); + currentRoot = currentRoot.ReplaceNode(baseList, newList); + } + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "BaseTypeRemoval", + Description = ex.Message, + OriginalCode = removal.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax AddBaseTypes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var addition in _plan.BaseTypeAdditions) + { + try + { + // Find the class declaration that has the annotation + var classDecl = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(c => c.HasAnnotation(addition.Annotation)); + + if (classDecl == null) continue; + + // Create the new base type + var newBaseType = SyntaxFactory.SimpleBaseType( + SyntaxFactory.ParseTypeName(addition.TypeName)); + + ClassDeclarationSyntax newClass; + if (classDecl.BaseList == null) + { + // Create new base list + var baseList = SyntaxFactory.BaseList( + SyntaxFactory.SingletonSeparatedList(newBaseType)) + .WithColonToken(SyntaxFactory.Token(SyntaxKind.ColonToken).WithTrailingTrivia(SyntaxFactory.Space)); + newClass = classDecl.WithBaseList(baseList); + } + else + { + // Add to existing base list + var newTypes = classDecl.BaseList.Types.Add(newBaseType); + var newBaseList = classDecl.BaseList.WithTypes(newTypes); + newClass = classDecl.WithBaseList(newBaseList); + } + + currentRoot = currentRoot.ReplaceNode(classDecl, newClass); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "BaseTypeAddition", + Description = ex.Message, + OriginalCode = addition.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax AddClassAttributes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var addition in _plan.ClassAttributeAdditions) + { + try + { + // Find the class declaration that has the annotation + var classDecl = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(c => c.HasAnnotation(addition.Annotation)); + + if (classDecl == null) continue; + + // Parse the complete attribute (handles generic types and arguments) + // Format: "ClassDataSource(Shared = SharedType.PerClass)" + var attrCode = addition.AttributeCode; + var attribute = ParseAttributeCode(attrCode); + + // Preserve class's leading trivia and put attribute list before it + var classLeadingTrivia = classDecl.GetLeadingTrivia(); + var attributeList = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList(attribute)) + .WithLeadingTrivia(classLeadingTrivia) + .WithTrailingTrivia(SyntaxFactory.EndOfLine("\n")); + + // Remove leading trivia from class (it's now on the attribute) + var newClass = classDecl + .WithLeadingTrivia(SyntaxFactory.TriviaList()) + .AddAttributeLists(attributeList); + + currentRoot = currentRoot.ReplaceNode(classDecl, newClass); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "ClassAttributeAddition", + Description = ex.Message, + OriginalCode = addition.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + /// + /// Parses an attribute code string that may include generic type arguments and parameters. + /// Format: "AttributeName" or "GenericAttr" or "Attr(args)" or "GenericAttr(args)" + /// + private static AttributeSyntax ParseAttributeCode(string attrCode) + { + // Parse as a complete attribute by wrapping in a dummy class + var code = $"[{attrCode}] class Dummy {{ }}"; + var tree = CSharpSyntaxTree.ParseText(code); + var attr = tree.GetRoot() + .DescendantNodes() + .OfType() + .FirstOrDefault(); + + return attr ?? SyntaxFactory.Attribute(SyntaxFactory.IdentifierName(attrCode)); + } + + private CompilationUnitSyntax AddMethodAttributes(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var addition in _plan.MethodAttributeAdditions) + { + try + { + // Find the method declaration that has the annotation + var method = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.HasAnnotation(addition.Annotation)); + + if (method == null) continue; + + // Parse and create the new attribute + // AttributeCode may include arguments like "Before(Test)" or just a name like "Test" + AttributeSyntax attribute; + var parenIndex = addition.AttributeCode.IndexOf('('); + if (parenIndex > 0) + { + // Has arguments: split into name and argument list + var attrName = addition.AttributeCode.Substring(0, parenIndex); + var argList = addition.AttributeCode.Substring(parenIndex); + attribute = SyntaxFactory.Attribute( + SyntaxFactory.ParseName(attrName), + SyntaxFactory.ParseAttributeArgumentList(argList)); + } + else + { + // No arguments: just the name + attribute = SyntaxFactory.Attribute(SyntaxFactory.ParseName(addition.AttributeCode)); + } + // Get the leading trivia from the first attribute (if any) or method + var leadingTrivia = method.AttributeLists.Count > 0 + ? method.AttributeLists[0].GetLeadingTrivia() + : method.GetLeadingTrivia(); + + // The leading trivia typically contains: [newlines...] [whitespace for indent] + // We want to: + // 1. Put the full leading trivia (including blank lines) on the new [Test] attribute + // 2. Put ONLY the trailing whitespace (indentation) on the first existing attribute + var triviaList = leadingTrivia.ToList(); + + // Find the last whitespace trivia - that's the indentation + var lastWhitespaceIndex = -1; + for (int i = triviaList.Count - 1; i >= 0; i--) + { + if (triviaList[i].IsKind(SyntaxKind.WhitespaceTrivia)) + { + lastWhitespaceIndex = i; + break; + } + } + + // Indentation is just the final whitespace trivia (if any) + var indentationTrivia = lastWhitespaceIndex >= 0 + ? SyntaxFactory.TriviaList(triviaList[lastWhitespaceIndex]) + : SyntaxFactory.TriviaList(); + + // Build the new list of attribute lists manually + var newAttributeLists = new List(); + + // Add the new [Test] attribute first with full leading trivia + // Explicitly clear trailing trivia - the first existing attribute's leading trivia has the newline + var newTestAttrList = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList(attribute)) + .WithLeadingTrivia(leadingTrivia) + .WithTrailingTrivia(SyntaxFactory.TriviaList()); + newAttributeLists.Add(newTestAttrList); + + // Add existing attributes, updating the first one's leading trivia + for (int i = 0; i < method.AttributeLists.Count; i++) + { + var existingAttr = method.AttributeLists[i]; + if (i == 0) + { + // First existing attribute: need newline + indentation + // The newline separates it from [Test], the indentation aligns it + var newLeading = SyntaxFactory.TriviaList( + SyntaxFactory.EndOfLine("\n")) + .AddRange(indentationTrivia); + existingAttr = existingAttr.WithLeadingTrivia(newLeading); + } + newAttributeLists.Add(existingAttr); + } + + var newMethod = method + .WithLeadingTrivia(SyntaxFactory.TriviaList()) + .WithAttributeLists(SyntaxFactory.List(newAttributeLists)); + + // Change return type if specified + if (!string.IsNullOrEmpty(addition.NewReturnType)) + { + newMethod = newMethod.WithReturnType( + SyntaxFactory.ParseTypeName(addition.NewReturnType) + .WithTrailingTrivia(SyntaxFactory.Space)); + } + + currentRoot = currentRoot.ReplaceNode(method, newMethod); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "MethodAttributeAddition", + Description = ex.Message, + OriginalCode = addition.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax RemoveMembers(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var removal in _plan.MemberRemovals) + { + try + { + var member = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.HasAnnotation(removal.Annotation)); + + if (member == null) continue; + + // Use KeepTrailingTrivia to preserve the newline after the member, + // but NOT KeepLeadingTrivia which would transfer the member's indentation to the next node + currentRoot = currentRoot.RemoveNode(member, SyntaxRemoveOptions.KeepTrailingTrivia)!; + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "MemberRemoval", + Description = ex.Message, + OriginalCode = removal.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax RemoveConstructorParameters(CompilationUnitSyntax root) + { + var currentRoot = root; + + foreach (var removal in _plan.ConstructorParameterRemovals) + { + try + { + var parameter = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(p => p.HasAnnotation(removal.Annotation)); + + if (parameter == null) continue; + + var parameterList = parameter.Ancestors() + .OfType() + .FirstOrDefault(); + + if (parameterList == null) continue; + + var newParams = parameterList.Parameters.Remove(parameter); + var newList = parameterList.WithParameters(newParams); + currentRoot = currentRoot.ReplaceNode(parameterList, newList); + } + catch (Exception ex) + { + _plan.Failures.Add(new ConversionFailure + { + Phase = "ConstructorParameterRemoval", + Description = ex.Message, + OriginalCode = removal.OriginalText, + Exception = ex + }); + } + } + + return currentRoot; + } + + private CompilationUnitSyntax TransformUsings(CompilationUnitSyntax root) + { + // Remove framework usings + root = MigrationHelpers.RemoveFrameworkUsings(root, _frameworkName); + + // Add TUnit usings (handled by MigrationHelpers which checks for async code, File/Directory usage, etc.) + root = MigrationHelpers.AddTUnitUsings(root); + + return root; + } + + private CompilationUnitSyntax AddFailureComments(CompilationUnitSyntax root) + { + var failureSummary = _plan.Failures + .GroupBy(f => f.Phase) + .Select(g => $"// TODO: TUnit migration - {g.Key}: {g.Count()} item(s) could not be converted") + .ToList(); + + if (failureSummary.Count == 0) + { + return root; + } + + var commentTrivia = new List + { + SyntaxFactory.Comment("// ============================================================"), + SyntaxFactory.EndOfLine("\n"), + SyntaxFactory.Comment("// TUnit Migration: Some items require manual attention"), + SyntaxFactory.EndOfLine("\n") + }; + + foreach (var summary in failureSummary) + { + commentTrivia.Add(SyntaxFactory.Comment(summary)); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + } + + commentTrivia.Add(SyntaxFactory.Comment("// ============================================================")); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + commentTrivia.Add(SyntaxFactory.EndOfLine("\n")); + + var existingTrivia = root.GetLeadingTrivia(); + return root.WithLeadingTrivia(SyntaxFactory.TriviaList(commentTrivia).AddRange(existingTrivia)); + } +} diff --git a/TUnit.Analyzers.CodeFixers/MSTestMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/MSTestMigrationCodeFixProvider.cs index 685c3c4d82..5dcb09e5f6 100644 --- a/TUnit.Analyzers.CodeFixers/MSTestMigrationCodeFixProvider.cs +++ b/TUnit.Analyzers.CodeFixers/MSTestMigrationCodeFixProvider.cs @@ -4,6 +4,8 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Analyzers.CodeFixers.Base; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; +using TUnit.Analyzers.CodeFixers.TwoPhase; using TUnit.Analyzers.Migrators.Base; namespace TUnit.Analyzers.CodeFixers; @@ -14,7 +16,17 @@ public class MSTestMigrationCodeFixProvider : BaseMigrationCodeFixProvider protected override string FrameworkName => "MSTest"; protected override string DiagnosticId => Rules.MSTestMigration.Id; protected override string CodeFixTitle => "Convert MSTest code to TUnit"; - + + protected override bool ShouldAddTUnitUsings() => true; + + protected override MigrationAnalyzer? CreateTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + { + return new MSTestTwoPhaseAnalyzer(semanticModel, compilation); + } + + // The following methods are required by the base class but are only used in the legacy + // conversion path. The two-phase analyzer handles these conversions directly. + protected override AttributeRewriter CreateAttributeRewriter(Compilation compilation) { return new MSTestAttributeRewriter(); @@ -43,6 +55,20 @@ protected override CompilationUnitSyntax ApplyFrameworkSpecificConversions(Compi return compilationUnit; } + + protected override CompilationUnitSyntax ApplyTwoPhasePostTransformations(CompilationUnitSyntax compilationUnit) + { + // Handle [ExpectedException] attribute conversion - this is a syntax-only transformation + // that converts [ExpectedException(typeof(T))] to await Assert.ThrowsAsync() + var expectedExceptionRewriter = new MSTestExpectedExceptionRewriter(); + compilationUnit = (CompilationUnitSyntax)expectedExceptionRewriter.Visit(compilationUnit); + + // Handle lifecycle method transformations - removes TestContext parameter + var lifecycleRewriter = new MSTestLifecycleRewriter(); + compilationUnit = (CompilationUnitSyntax)lifecycleRewriter.Visit(compilationUnit); + + return compilationUnit; + } } public class MSTestAttributeRewriter : AttributeRewriter @@ -950,15 +976,31 @@ public class MSTestLifecycleRewriter : CSharpSyntaxRewriter public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node) { // Handle ClassInitialize, ClassCleanup, TestInitialize, TestCleanup - remove TestContext parameter where applicable - var lifecycleAttributes = node.AttributeLists + // Also check for the converted TUnit attribute names (Before/After with HookType.Class) + var allAttributes = node.AttributeLists .SelectMany(al => al.Attributes) + .ToList(); + + var lifecycleAttributeNames = allAttributes .Select(a => MigrationHelpers.GetAttributeName(a)) .ToList(); - var hasClassLifecycle = lifecycleAttributes.Any(name => name is "ClassInitialize" or "ClassCleanup"); - var hasTestLifecycle = lifecycleAttributes.Any(name => name is "TestInitialize" or "TestCleanup"); + var hasClassLifecycle = lifecycleAttributeNames.Any(name => name is "ClassInitialize" or "ClassCleanup"); + var hasTestLifecycle = lifecycleAttributeNames.Any(name => name is "TestInitialize" or "TestCleanup"); - if (hasClassLifecycle || hasTestLifecycle) + // Also check for converted TUnit attributes: [Before(HookType.Class)] or [After(HookType.Class)] + var hasTUnitClassLifecycle = allAttributes.Any(a => + { + var attrName = MigrationHelpers.GetAttributeName(a); + if (attrName is not ("Before" or "After")) + return false; + + // Check if it has HookType.Class argument + return a.ArgumentList?.Arguments.Any(arg => + arg.Expression.ToString().Contains("HookType.Class")) == true; + }); + + if (hasClassLifecycle || hasTestLifecycle || hasTUnitClassLifecycle) { // Remove TestContext parameter if present var parameters = node.ParameterList?.Parameters ?? default; @@ -973,8 +1015,8 @@ public class MSTestLifecycleRewriter : CSharpSyntaxRewriter node = node.AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword)); } - // Make sure ClassInitialize/ClassCleanup are static - if (hasClassLifecycle && !node.Modifiers.Any(SyntaxKind.StaticKeyword)) + // Make sure ClassInitialize/ClassCleanup are static (and converted TUnit class-level hooks) + if ((hasClassLifecycle || hasTUnitClassLifecycle) && !node.Modifiers.Any(SyntaxKind.StaticKeyword)) { node = node.AddModifiers(SyntaxFactory.Token(SyntaxKind.StaticKeyword)); } @@ -1022,11 +1064,31 @@ public class MSTestExpectedExceptionRewriter : CSharpSyntaxRewriter return node.WithAttributeLists(SyntaxFactory.List(newAttributeLists)); } - return node + var result = node .WithAttributeLists(SyntaxFactory.List(newAttributeLists)) .WithBody(newBody) .WithExpressionBody(null) .WithSemicolonToken(default); + + // If the method isn't already async, make it async Task + var isAlreadyAsync = result.Modifiers.Any(m => m.IsKind(SyntaxKind.AsyncKeyword)); + if (!isAlreadyAsync) + { + var asyncModifier = SyntaxFactory.Token(SyntaxKind.AsyncKeyword).WithTrailingTrivia(SyntaxFactory.Space); + var newModifiers = result.Modifiers.Add(asyncModifier); + result = result.WithModifiers(newModifiers); + + // Change void to Task + if (result.ReturnType is PredefinedTypeSyntax predefined && + predefined.Keyword.IsKind(SyntaxKind.VoidKeyword)) + { + var taskReturnType = SyntaxFactory.IdentifierName("Task") + .WithTrailingTrivia(result.ReturnType.GetTrailingTrivia()); + result = result.WithReturnType(taskReturnType); + } + } + + return result; } private static TypeSyntax? ExtractExceptionType(AttributeSyntax attribute) diff --git a/TUnit.Analyzers.CodeFixers/NUnitAssertMultipleRewriter.cs b/TUnit.Analyzers.CodeFixers/NUnitAssertMultipleRewriter.cs new file mode 100644 index 0000000000..75fc067c82 --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/NUnitAssertMultipleRewriter.cs @@ -0,0 +1,112 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace TUnit.Analyzers.CodeFixers; + +/// +/// Converts Assert.Multiple(() => { ... }) to using (Assert.Multiple()) { ... } +/// This is a syntax-only transformation that doesn't require a semantic model. +/// +public class NUnitAssertMultipleRewriter : CSharpSyntaxRewriter +{ + public override SyntaxNode? VisitExpressionStatement(ExpressionStatementSyntax node) + { + // Check if this is Assert.Multiple(() => { ... }) + if (node.Expression is InvocationExpressionSyntax invocation && + invocation.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" } && + memberAccess.Name.Identifier.Text == "Multiple" && + invocation.ArgumentList.Arguments.Count == 1) + { + var argument = invocation.ArgumentList.Arguments[0].Expression; + + // Handle lambda: Assert.Multiple(() => { ... }) + if (argument is ParenthesizedLambdaExpressionSyntax lambda) + { + return ConvertAssertMultipleLambda(node, lambda); + } + + // Handle simple lambda: Assert.Multiple(() => expr) + if (argument is SimpleLambdaExpressionSyntax simpleLambda) + { + return ConvertAssertMultipleSimpleLambda(node, simpleLambda); + } + } + + return base.VisitExpressionStatement(node); + } + + private SyntaxNode ConvertAssertMultipleLambda(ExpressionStatementSyntax originalStatement, ParenthesizedLambdaExpressionSyntax lambda) + { + // Extract statements from lambda body + SyntaxList statements; + if (lambda.Body is BlockSyntax block) + { + // Visit each statement to convert inner assertions + var convertedStatements = block.Statements.Select(s => (StatementSyntax)Visit(s)!).ToArray(); + statements = SyntaxFactory.List(convertedStatements); + } + else if (lambda.Body is ExpressionSyntax expr) + { + // Single expression lambda - convert it + var visitedExpr = (ExpressionSyntax)Visit(expr)!; + statements = SyntaxFactory.SingletonList( + SyntaxFactory.ExpressionStatement(visitedExpr)); + } + else + { + return originalStatement; + } + + return CreateUsingMultipleStatement(originalStatement, statements); + } + + private SyntaxNode ConvertAssertMultipleSimpleLambda(ExpressionStatementSyntax originalStatement, SimpleLambdaExpressionSyntax lambda) + { + SyntaxList statements; + if (lambda.Body is BlockSyntax block) + { + var convertedStatements = block.Statements.Select(s => (StatementSyntax)Visit(s)!).ToArray(); + statements = SyntaxFactory.List(convertedStatements); + } + else if (lambda.Body is ExpressionSyntax expr) + { + var visitedExpr = (ExpressionSyntax)Visit(expr)!; + statements = SyntaxFactory.SingletonList( + SyntaxFactory.ExpressionStatement(visitedExpr)); + } + else + { + return originalStatement; + } + + return CreateUsingMultipleStatement(originalStatement, statements); + } + + private UsingStatementSyntax CreateUsingMultipleStatement(ExpressionStatementSyntax originalStatement, SyntaxList statements) + { + // Create: Assert.Multiple() + var assertMultipleInvocation = SyntaxFactory.InvocationExpression( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName("Assert"), + SyntaxFactory.IdentifierName("Multiple")), + SyntaxFactory.ArgumentList()); + + // Create the using statement: using (Assert.Multiple()) { ... } + var usingStatement = SyntaxFactory.UsingStatement( + declaration: null, + expression: assertMultipleInvocation, + statement: SyntaxFactory.Block(statements) + .WithOpenBraceToken(SyntaxFactory.Token(SyntaxKind.OpenBraceToken).WithLeadingTrivia(SyntaxFactory.LineFeed)) + .WithCloseBraceToken(SyntaxFactory.Token(SyntaxKind.CloseBraceToken).WithLeadingTrivia(originalStatement.GetLeadingTrivia()))); + + return usingStatement + .WithUsingKeyword(SyntaxFactory.Token(SyntaxKind.UsingKeyword).WithTrailingTrivia(SyntaxFactory.Space)) + .WithOpenParenToken(SyntaxFactory.Token(SyntaxKind.OpenParenToken)) + .WithCloseParenToken(SyntaxFactory.Token(SyntaxKind.CloseParenToken)) + .WithLeadingTrivia(originalStatement.GetLeadingTrivia()) + .WithTrailingTrivia(originalStatement.GetTrailingTrivia()); + } +} diff --git a/TUnit.Analyzers.CodeFixers/NUnitExpectedResultRewriter.cs b/TUnit.Analyzers.CodeFixers/NUnitExpectedResultRewriter.cs index 45046abf86..cc39d05877 100644 --- a/TUnit.Analyzers.CodeFixers/NUnitExpectedResultRewriter.cs +++ b/TUnit.Analyzers.CodeFixers/NUnitExpectedResultRewriter.cs @@ -524,10 +524,11 @@ private AttributeSyntax TransformTestCaseAttribute(AttributeSyntax attribute) newArgs.Add(categoriesArg); } - var newAttribute = attribute.WithArgumentList( + // Create a new Arguments attribute (renamed from TestCase) + var newAttribute = SyntaxFactory.Attribute( + SyntaxFactory.IdentifierName("Arguments"), SyntaxFactory.AttributeArgumentList(SyntaxFactory.SeparatedList(newArgs))); - // The attribute will be renamed to "Arguments" by the existing attribute rewriter return newAttribute; } diff --git a/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs index a731f3dd84..e758025391 100644 --- a/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs +++ b/TUnit.Analyzers.CodeFixers/NUnitMigrationCodeFixProvider.cs @@ -4,6 +4,8 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Analyzers.CodeFixers.Base; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; +using TUnit.Analyzers.CodeFixers.TwoPhase; using TUnit.Analyzers.Migrators.Base; namespace TUnit.Analyzers.CodeFixers; @@ -57,6 +59,31 @@ protected override CompilationUnitSyntax ApplyFrameworkSpecificConversions(Compi /// NUnit allows [TestCase] alone, but TUnit requires [Test] + [Arguments]. /// protected override bool ShouldEnsureTestAttribute() => true; + + protected override bool ShouldAddTUnitUsings() => true; + + protected override MigrationAnalyzer? CreateTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + { + return new NUnitTwoPhaseAnalyzer(semanticModel, compilation); + } + + protected override CompilationUnitSyntax ApplyTwoPhasePostTransformations(CompilationUnitSyntax compilationUnit) + { + // Transform ExpectedResult patterns (TestCase with ExpectedResult → Arguments with assertion) + // The ExpectedResultRewriter doesn't actually need the semantic model (it's kept for API compatibility) + var expectedResultRewriter = new NUnitExpectedResultRewriter(null!); + compilationUnit = (CompilationUnitSyntax)expectedResultRewriter.Visit(compilationUnit); + + // Handle [ExpectedException] attribute conversion + var expectedExceptionRewriter = new NUnitExpectedExceptionRewriter(); + compilationUnit = (CompilationUnitSyntax)expectedExceptionRewriter.Visit(compilationUnit); + + // Handle Assert.Multiple(() => { ... }) → using (Assert.Multiple()) { ... } + var assertMultipleRewriter = new NUnitAssertMultipleRewriter(); + compilationUnit = (CompilationUnitSyntax)assertMultipleRewriter.Visit(compilationUnit); + + return compilationUnit; + } } public class NUnitAttributeRewriter : AttributeRewriter @@ -2161,11 +2188,34 @@ public class NUnitExpectedExceptionRewriter : CSharpSyntaxRewriter return node.WithAttributeLists(SyntaxFactory.List(newAttributeLists)); } - return node + // Check if the method is already async + var isAlreadyAsync = node.Modifiers.Any(m => m.IsKind(SyntaxKind.AsyncKeyword)); + + var result = node .WithAttributeLists(SyntaxFactory.List(newAttributeLists)) .WithBody(newBody) .WithExpressionBody(null) .WithSemicolonToken(default); + + // If not already async, make it async Task + if (!isAlreadyAsync) + { + // Add async modifier + var asyncModifier = SyntaxFactory.Token(SyntaxKind.AsyncKeyword).WithTrailingTrivia(SyntaxFactory.Space); + var newModifiers = result.Modifiers.Add(asyncModifier); + result = result.WithModifiers(newModifiers); + + // Change return type from void to Task + if (result.ReturnType is PredefinedTypeSyntax predefined && + predefined.Keyword.IsKind(SyntaxKind.VoidKeyword)) + { + var taskReturnType = SyntaxFactory.IdentifierName("Task") + .WithTrailingTrivia(result.ReturnType.GetTrailingTrivia()); + result = result.WithReturnType(taskReturnType); + } + } + + return result; } private static TypeSyntax? ExtractExceptionType(AttributeSyntax attribute) diff --git a/TUnit.Analyzers.CodeFixers/TwoPhase/MSTestTwoPhaseAnalyzer.cs b/TUnit.Analyzers.CodeFixers/TwoPhase/MSTestTwoPhaseAnalyzer.cs new file mode 100644 index 0000000000..f71a922370 --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/TwoPhase/MSTestTwoPhaseAnalyzer.cs @@ -0,0 +1,875 @@ +using System.Collections.Generic; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; +using TUnit.Analyzers.Migrators.Base; + +namespace TUnit.Analyzers.CodeFixers.TwoPhase; + +/// +/// Phase 1 analyzer for MSTest to TUnit migration. +/// Collects all conversion targets while the semantic model is valid. +/// +public class MSTestTwoPhaseAnalyzer : MigrationAnalyzer +{ + private static readonly HashSet MSTestAssertMethods = new() + { + "AreEqual", "AreNotEqual", "AreSame", "AreNotSame", + "IsTrue", "IsFalse", + "IsNull", "IsNotNull", + "IsInstanceOfType", "IsNotInstanceOfType", + "ThrowsException", "ThrowsExceptionAsync", + "Fail", "Inconclusive" + }; + + private static readonly HashSet MSTestCollectionAssertMethods = new() + { + "AreEqual", "AreNotEqual", "AreEquivalent", "AreNotEquivalent", + "Contains", "DoesNotContain", + "IsSubsetOf", "IsNotSubsetOf", + "AllItemsAreUnique", "AllItemsAreNotNull", "AllItemsAreInstancesOfType" + }; + + private static readonly HashSet MSTestStringAssertMethods = new() + { + "Contains", "StartsWith", "EndsWith", "Matches", "DoesNotMatch" + }; + + private static readonly HashSet MSTestFileAssertMethods = new() + { + "Exists", "DoesNotExist" + }; + + private static readonly HashSet MSTestDirectoryAssertMethods = new() + { + "Exists", "DoesNotExist" + }; + + private static readonly HashSet MSTestAttributeNames = new() + { + "TestClass", "TestMethod", "DataRow", "DynamicData", + "TestInitialize", "TestCleanup", "ClassInitialize", "ClassCleanup", + "TestCategory", "Ignore", "Priority", "Owner", "ExpectedException" + }; + + private static readonly HashSet MSTestRemovableAttributeNames = new() + { + "TestClass" // TestClass is implicit in TUnit + }; + + public MSTestTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + : base(semanticModel, compilation) + { + } + + protected override IEnumerable FindAssertionNodes(CompilationUnitSyntax root) + { + return root.DescendantNodes() + .OfType() + .Where(IsMSTestAssertion); + } + + private bool IsMSTestAssertion(InvocationExpressionSyntax invocation) + { + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess) + { + if (memberAccess.Expression is IdentifierNameSyntax identifier) + { + var typeName = identifier.Identifier.Text; + var methodName = memberAccess.Name.Identifier.Text; + + // Check Assert methods + if (typeName == "Assert" && MSTestAssertMethods.Contains(methodName)) + { + return VerifyMSTestNamespace(invocation); + } + + // Check CollectionAssert methods + if (typeName == "CollectionAssert" && MSTestCollectionAssertMethods.Contains(methodName)) + { + return VerifyMSTestNamespace(invocation); + } + + // Check StringAssert methods + if (typeName == "StringAssert" && MSTestStringAssertMethods.Contains(methodName)) + { + return VerifyMSTestNamespace(invocation); + } + + // Check FileAssert methods + if (typeName == "FileAssert" && MSTestFileAssertMethods.Contains(methodName)) + { + return VerifyMSTestNamespace(invocation); + } + + // Check DirectoryAssert methods + if (typeName == "DirectoryAssert" && MSTestDirectoryAssertMethods.Contains(methodName)) + { + return VerifyMSTestNamespace(invocation); + } + } + } + + return false; + } + + private bool VerifyMSTestNamespace(InvocationExpressionSyntax invocation) + { + try + { + var symbolInfo = SemanticModel.GetSymbolInfo(invocation); + if (symbolInfo.Symbol is IMethodSymbol methodSymbol) + { + var containingNamespace = methodSymbol.ContainingType?.ContainingNamespace?.ToDisplayString(); + return containingNamespace?.StartsWith("Microsoft.VisualStudio.TestTools.UnitTesting") == true; + } + } + catch + { + // Fall back to syntax-based detection + } + + return true; // Assume it's MSTest if syntax matches + } + + protected override AssertionConversion? AnalyzeAssertion(InvocationExpressionSyntax node) + { + if (node.Expression is not MemberAccessExpressionSyntax memberAccess) + return null; + + var typeName = (memberAccess.Expression as IdentifierNameSyntax)?.Identifier.Text ?? ""; + var methodName = memberAccess.Name.Identifier.Text; + var arguments = node.ArgumentList.Arguments; + + var (kind, replacementCode, introducesAwait, todoComment) = typeName switch + { + "Assert" => ConvertAssertMethod(methodName, arguments, memberAccess), + "CollectionAssert" => ConvertCollectionAssertMethod(methodName, arguments), + "StringAssert" => ConvertStringAssertMethod(methodName, arguments), + "FileAssert" => ConvertFileAssertMethod(methodName, arguments), + "DirectoryAssert" => ConvertDirectoryAssertMethod(methodName, arguments), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAssertMethod( + string methodName, SeparatedSyntaxList args, MemberAccessExpressionSyntax memberAccess) + { + return methodName switch + { + "AreEqual" => ConvertAreEqual(args), + "AreNotEqual" => ConvertAreNotEqual(args), + "AreSame" => ConvertAreSame(args), + "AreNotSame" => ConvertAreNotSame(args), + "IsTrue" => ConvertIsTrue(args), + "IsFalse" => ConvertIsFalse(args), + "IsNull" => ConvertIsNull(args), + "IsNotNull" => ConvertIsNotNull(args), + "IsInstanceOfType" => ConvertIsInstanceOfType(args), + "IsNotInstanceOfType" => ConvertIsNotInstanceOfType(args), + "ThrowsException" => ConvertThrowsException(memberAccess, args), + "ThrowsExceptionAsync" => ConvertThrowsExceptionAsync(memberAccess, args), + "Fail" => ConvertFail(args), + "Inconclusive" => ConvertInconclusive(args), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + // Check for message parameter (3rd or later) + var (message, hasComparer) = GetMessageWithFormatArgs(args, 2); + + string? todoComment = hasComparer + ? "// TODO: TUnit migration - IEqualityComparer was used. TUnit uses .IsEqualTo() which may have different comparison semantics." + : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEqualTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsEqualTo({expected})"; + + return (AssertionConversionKind.AreEqual, assertion, true, todoComment); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreNotEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreNotEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + var (message, hasComparer) = GetMessageWithFormatArgs(args, 2); + + string? todoComment = hasComparer + ? "// TODO: TUnit migration - IEqualityComparer was used. TUnit uses .IsNotEqualTo() which may have different comparison semantics." + : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotEqualTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotEqualTo({expected})"; + + return (AssertionConversionKind.AreNotEqual, assertion, true, todoComment); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreSame(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreSame, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsSameReferenceAs({expected}).Because({message})" + : $"await Assert.That({actual}).IsSameReferenceAs({expected})"; + + return (AssertionConversionKind.AreSame, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreNotSame(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreNotSame, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotSameReferenceAs({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotSameReferenceAs({expected})"; + + return (AssertionConversionKind.AreNotSame, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsTrue(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsTrue, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsTrue().Because({message})" + : $"await Assert.That({value}).IsTrue()"; + + return (AssertionConversionKind.IsTrue, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsFalse(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsFalse, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsFalse().Because({message})" + : $"await Assert.That({value}).IsFalse()"; + + return (AssertionConversionKind.IsFalse, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsNull, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsNull().Because({message})" + : $"await Assert.That({value}).IsNull()"; + + return (AssertionConversionKind.IsNull, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNotNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsNotNull, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsNotNull().Because({message})" + : $"await Assert.That({value}).IsNotNull()"; + + return (AssertionConversionKind.IsNotNull, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsInstanceOfType(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.IsInstanceOfType, null, false, null); + + var value = args[0].Expression.ToString(); + var expectedType = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsAssignableTo({expectedType}).Because({message})" + : $"await Assert.That({value}).IsAssignableTo({expectedType})"; + + return (AssertionConversionKind.IsInstanceOfType, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNotInstanceOfType(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.IsNotInstanceOfType, null, false, null); + + var value = args[0].Expression.ToString(); + var expectedType = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsNotAssignableTo({expectedType}).Because({message})" + : $"await Assert.That({value}).IsNotAssignableTo({expectedType})"; + + return (AssertionConversionKind.IsNotInstanceOfType, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrowsException( + MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.ThrowsException, null, false, null); + + var action = args[0].Expression.ToString(); + + // Get generic type argument if present + string typeArg = "Exception"; + if (memberAccess.Name is GenericNameSyntax genericName && + genericName.TypeArgumentList.Arguments.Count > 0) + { + typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + } + + var assertion = $"await Assert.ThrowsAsync<{typeArg}>({action})"; + return (AssertionConversionKind.ThrowsException, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrowsExceptionAsync( + MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + // Same conversion as ThrowsException + return ConvertThrowsException(memberAccess, args); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertFail(SeparatedSyntaxList args) + { + var message = args.Count > 0 ? args[0].Expression.ToString() : "\"\""; + var assertion = $"await Assert.Fail({message})"; + return (AssertionConversionKind.Fail, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertInconclusive(SeparatedSyntaxList args) + { + var message = args.Count > 0 ? args[0].Expression.ToString() : "\"Test inconclusive\""; + var assertion = $"await Assert.Skip({message})"; + return (AssertionConversionKind.Inconclusive, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAssertMethod( + string methodName, SeparatedSyntaxList args) + { + return methodName switch + { + "AreEqual" => ConvertCollectionAreEqual(args), + "AreNotEqual" => ConvertCollectionAreNotEqual(args), + "AreEquivalent" => ConvertCollectionAreEquivalent(args), + "AreNotEquivalent" => ConvertCollectionAreNotEquivalent(args), + "Contains" => ConvertCollectionContains(args), + "DoesNotContain" => ConvertCollectionDoesNotContain(args), + "IsSubsetOf" => ConvertCollectionIsSubsetOf(args), + "IsNotSubsetOf" => ConvertCollectionIsNotSubsetOf(args), + "AllItemsAreUnique" => ConvertCollectionAllItemsAreUnique(args), + "AllItemsAreNotNull" => ConvertCollectionAllItemsAreNotNull(args), + "AllItemsAreInstancesOfType" => ConvertCollectionAllItemsAreInstancesOfType(args), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionAreEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreNotEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionAreNotEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreNotEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreEquivalent(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionAreEquivalent, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreEquivalent, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreNotEquivalent(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionAreNotEquivalent, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreNotEquivalent, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionContains(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionContains, null, false, null); + + var collection = args[0].Expression.ToString(); + var element = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).Contains({element}).Because({message})" + : $"await Assert.That({collection}).Contains({element})"; + + return (AssertionConversionKind.CollectionContains, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionDoesNotContain(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionDoesNotContain, null, false, null); + + var collection = args[0].Expression.ToString(); + var element = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).DoesNotContain({element}).Because({message})" + : $"await Assert.That({collection}).DoesNotContain({element})"; + + return (AssertionConversionKind.CollectionDoesNotContain, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionIsSubsetOf(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionIsSubsetOf, null, false, null); + + var subset = args[0].Expression.ToString(); + var superset = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({subset}).IsSubsetOf({superset}).Because({message})" + : $"await Assert.That({subset}).IsSubsetOf({superset})"; + + return (AssertionConversionKind.CollectionIsSubsetOf, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionIsNotSubsetOf(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionIsNotSubsetOf, null, false, null); + + var subset = args[0].Expression.ToString(); + var superset = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({subset}).IsNotSubsetOf({superset}).Because({message})" + : $"await Assert.That({subset}).IsNotSubsetOf({superset})"; + + return (AssertionConversionKind.CollectionIsNotSubsetOf, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAllItemsAreUnique(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.CollectionAllItemsAreUnique, null, false, null); + + var collection = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).HasDistinctItems().Because({message})" + : $"await Assert.That({collection}).HasDistinctItems()"; + + return (AssertionConversionKind.CollectionAllItemsAreUnique, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAllItemsAreNotNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.CollectionAllItemsAreNotNull, null, false, null); + + var collection = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).All(x => x != null).Because({message})" + : $"await Assert.That({collection}).All(x => x != null)"; + + return (AssertionConversionKind.CollectionAllItemsAreNotNull, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAllItemsAreInstancesOfType(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.CollectionAllItemsAreInstancesOfType, null, false, null); + + var collection = args[0].Expression.ToString(); + var expectedType = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).All(x => {expectedType}.IsInstanceOfType(x)).Because({message})" + : $"await Assert.That({collection}).All(x => {expectedType}.IsInstanceOfType(x))"; + + return (AssertionConversionKind.CollectionAllItemsAreInstancesOfType, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringAssertMethod( + string methodName, SeparatedSyntaxList args) + { + return methodName switch + { + "Contains" => ConvertStringContains(args), + "StartsWith" => ConvertStringStartsWith(args), + "EndsWith" => ConvertStringEndsWith(args), + "Matches" => ConvertStringMatches(args), + "DoesNotMatch" => ConvertStringDoesNotMatch(args), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringContains(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StringContains, null, false, null); + + var value = args[0].Expression.ToString(); + var substring = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).Contains({substring}).Because({message})" + : $"await Assert.That({value}).Contains({substring})"; + + return (AssertionConversionKind.StringContains, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringStartsWith(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StringStartsWith, null, false, null); + + var value = args[0].Expression.ToString(); + var prefix = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).StartsWith({prefix}).Because({message})" + : $"await Assert.That({value}).StartsWith({prefix})"; + + return (AssertionConversionKind.StringStartsWith, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringEndsWith(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StringEndsWith, null, false, null); + + var value = args[0].Expression.ToString(); + var suffix = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).EndsWith({suffix}).Because({message})" + : $"await Assert.That({value}).EndsWith({suffix})"; + + return (AssertionConversionKind.StringEndsWith, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringMatches(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StringMatches, null, false, null); + + var value = args[0].Expression.ToString(); + var pattern = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).Matches({pattern}).Because({message})" + : $"await Assert.That({value}).Matches({pattern})"; + + return (AssertionConversionKind.StringMatches, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringDoesNotMatch(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StringDoesNotMatch, null, false, null); + + var value = args[0].Expression.ToString(); + var pattern = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args, 2) : null; + + var assertion = message != null + ? $"await Assert.That({value}).DoesNotMatch({pattern}).Because({message})" + : $"await Assert.That({value}).DoesNotMatch({pattern})"; + + return (AssertionConversionKind.StringDoesNotMatch, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertFileAssertMethod( + string methodName, SeparatedSyntaxList args) + { + if (args.Count < 1) + return (AssertionConversionKind.Unknown, null, false, null); + + var fileArg = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + // MSTest FileAssert methods take a FileInfo argument + // FileAssert.Exists(file) -> Assert.That(file.Exists).IsTrue() + // FileAssert.DoesNotExist(file) -> Assert.That(file.Exists).IsFalse() + return methodName switch + { + "Exists" => message != null + ? (AssertionConversionKind.True, $"await Assert.That({fileArg}.Exists).IsTrue().Because({message})", true, null) + : (AssertionConversionKind.True, $"await Assert.That({fileArg}.Exists).IsTrue()", true, null), + "DoesNotExist" => message != null + ? (AssertionConversionKind.False, $"await Assert.That({fileArg}.Exists).IsFalse().Because({message})", true, null) + : (AssertionConversionKind.False, $"await Assert.That({fileArg}.Exists).IsFalse()", true, null), + _ => (AssertionConversionKind.Unknown, null, false, $"// TODO: TUnit migration - FileAssert.{methodName} has no direct equivalent") + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDirectoryAssertMethod( + string methodName, SeparatedSyntaxList args) + { + if (args.Count < 1) + return (AssertionConversionKind.Unknown, null, false, null); + + var dirArg = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args, 1) : null; + + // MSTest DirectoryAssert methods take a DirectoryInfo argument + // DirectoryAssert.Exists(dir) -> Assert.That(dir.Exists).IsTrue() + // DirectoryAssert.DoesNotExist(dir) -> Assert.That(dir.Exists).IsFalse() + return methodName switch + { + "Exists" => message != null + ? (AssertionConversionKind.True, $"await Assert.That({dirArg}.Exists).IsTrue().Because({message})", true, null) + : (AssertionConversionKind.True, $"await Assert.That({dirArg}.Exists).IsTrue()", true, null), + "DoesNotExist" => message != null + ? (AssertionConversionKind.False, $"await Assert.That({dirArg}.Exists).IsFalse().Because({message})", true, null) + : (AssertionConversionKind.False, $"await Assert.That({dirArg}.Exists).IsFalse()", true, null), + _ => (AssertionConversionKind.Unknown, null, false, $"// TODO: TUnit migration - DirectoryAssert.{methodName} has no direct equivalent") + }; + } + + /// + /// Gets the message argument, handling format strings and detecting comparer usage. + /// Returns (message, hasComparer) where message is wrapped in string.Format if format args present. + /// + private static (string? message, bool hasComparer) GetMessageWithFormatArgs(SeparatedSyntaxList args, int startIndex) + { + if (args.Count <= startIndex) + return (null, false); + + var arg = args[startIndex]; + + // Check if it's named "message" + if (arg.NameColon?.Name.Identifier.Text == "message") + { + return (arg.Expression.ToString(), false); + } + + // Check if it's a string literal (potential format string or simple message) + if (arg.Expression is LiteralExpressionSyntax literal && + literal.IsKind(SyntaxKind.StringLiteralExpression)) + { + var messageString = arg.Expression.ToString(); + + // Check if there are format arguments after the message (4+ total args means format string) + if (args.Count > startIndex + 1) + { + // Collect format args + var formatArgs = new List(); + for (int i = startIndex + 1; i < args.Count; i++) + { + formatArgs.Add(args[i].Expression.ToString()); + } + // Wrap in string.Format + return ($"string.Format({messageString}, {string.Join(", ", formatArgs)})", false); + } + + return (messageString, false); + } + + // Check for interpolated string + if (arg.Expression is InterpolatedStringExpressionSyntax) + { + return (arg.Expression.ToString(), false); + } + + // Not a string - likely a comparer + return (null, true); + } + + /// + /// Gets a simple message argument (for methods that don't support format strings like CollectionAssert). + /// + private static string? GetMessageArgument(SeparatedSyntaxList args, int startIndex) + { + if (args.Count > startIndex) + { + var arg = args[startIndex]; + // Check if it's named "message" + if (arg.NameColon?.Name.Identifier.Text == "message") + { + return arg.Expression.ToString(); + } + // Check if it looks like a string + if (arg.Expression is LiteralExpressionSyntax literal && + literal.IsKind(SyntaxKind.StringLiteralExpression)) + { + return arg.Expression.ToString(); + } + if (arg.Expression is InterpolatedStringExpressionSyntax) + { + return arg.Expression.ToString(); + } + } + return null; + } + + protected override bool ShouldRemoveAttribute(AttributeSyntax node) + { + var name = MigrationHelpers.GetAttributeName(node); + return MSTestRemovableAttributeNames.Contains(name); + } + + protected override AttributeConversion? AnalyzeAttribute(AttributeSyntax node) + { + var name = MigrationHelpers.GetAttributeName(node); + + if (!MSTestAttributeNames.Contains(name)) + return null; + + var (newName, newArgs) = name switch + { + "TestMethod" => ("Test", null), + "DataRow" => ("Arguments", node.ArgumentList?.ToString()), + "DynamicData" => ("MethodDataSource", ConvertDynamicDataArgs(node)), + "TestInitialize" => ("Before", "(HookType.Test)"), + "TestCleanup" => ("After", "(HookType.Test)"), + "ClassInitialize" => ("Before", "(HookType.Class)"), + "ClassCleanup" => ("After", "(HookType.Class)"), + "TestCategory" => ("Property", ConvertTestCategoryArgs(node)), + "Ignore" => ("Skip", node.ArgumentList?.ToString()), + "Priority" => ("Property", ConvertPriorityArgs(node)), + "Owner" => ("Property", ConvertOwnerArgs(node)), + "ExpectedException" => (null, null), // Handled separately + _ => (null, null) + }; + + if (newName == null) + return null; + + return new AttributeConversion + { + NewAttributeName = newName, + NewArgumentList = newArgs, + OriginalText = node.ToString() + }; + } + + private static string? ConvertDynamicDataArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var firstArg = node.ArgumentList.Arguments[0]; + return $"({firstArg.Expression})"; + } + return null; + } + + private static string? ConvertTestCategoryArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var value = node.ArgumentList.Arguments[0].Expression.ToString(); + return $"(\"Category\", {value})"; + } + return null; + } + + private static string? ConvertPriorityArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var value = node.ArgumentList.Arguments[0].Expression.ToString(); + return $"(\"Priority\", \"{value}\")"; + } + return null; + } + + private static string? ConvertOwnerArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var value = node.ArgumentList.Arguments[0].Expression.ToString(); + return $"(\"Owner\", {value})"; + } + return null; + } + + protected override bool ShouldRemoveBaseType(BaseTypeSyntax baseType) + { + // MSTest doesn't have common base types to remove + return false; + } + + protected override void AnalyzeUsings() + { + Plan.UsingPrefixesToRemove.Add("Microsoft.VisualStudio.TestTools"); + // TUnit usings are handled automatically by MigrationHelpers + } +} diff --git a/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs b/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs new file mode 100644 index 0000000000..5feb1de6af --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/TwoPhase/NUnitTwoPhaseAnalyzer.cs @@ -0,0 +1,1926 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; +using TUnit.Analyzers.Migrators.Base; + +namespace TUnit.Analyzers.CodeFixers.TwoPhase; + +/// +/// Phase 1 analyzer for NUnit to TUnit migration. +/// Collects all conversion targets while the semantic model is valid. +/// +public class NUnitTwoPhaseAnalyzer : MigrationAnalyzer +{ + private static readonly HashSet NUnitClassicAssertMethods = new() + { + "AreEqual", "AreNotEqual", "AreSame", "AreNotSame", + "IsTrue", "IsFalse", + "IsNull", "IsNotNull", + "IsEmpty", "IsNotEmpty", + "IsInstanceOf", "IsNotInstanceOf", + "IsAssignableFrom", "IsNotAssignableFrom", + "Greater", "GreaterOrEqual", "Less", "LessOrEqual", + "Contains", "DoesNotContain", + "Throws", "ThrowsAsync", "DoesNotThrow", "DoesNotThrowAsync", + "Pass", "Fail", "Inconclusive", "Ignore", "Warn", + "Positive", "Negative", "Zero", "NotZero", + "Catch", "CatchAsync" + }; + + private static readonly HashSet NUnitAttributeNames = new() + { + "Test", "Theory", "TestCase", "TestCaseSource", + "SetUp", "TearDown", "OneTimeSetUp", "OneTimeTearDown", + "TestFixture", "Category", "Ignore", "Explicit", + "Description", "Author", "Apartment", + "Parallelizable", "NonParallelizable", + "Repeat", "Values", "Range", "ValueSource", + "Sequential", "Combinatorial", "Platform", + "ExpectedException" + }; + + private static readonly HashSet NUnitRemovableAttributeNames = new() + { + "TestFixture", // TestFixture is implicit in TUnit + "Combinatorial", // TUnit's default behavior is combinatorial + "Sequential" // No direct equivalent - TUnit uses Matrix which is combinatorial by default + }; + + private static readonly HashSet NUnitConditionallyRemovableAttributes = new() + { + "Parallelizable" // Only removed when NOT ParallelScope.None + }; + + public NUnitTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + : base(semanticModel, compilation) + { + } + + protected override IEnumerable FindAssertionNodes(CompilationUnitSyntax root) + { + return root.DescendantNodes() + .OfType() + .Where(IsNUnitAssertion); + } + + private bool IsNUnitAssertion(InvocationExpressionSyntax invocation) + { + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess) + { + var typeName = GetSimpleTypeName(memberAccess.Expression); + var methodName = memberAccess.Name.Identifier.Text; + + // Check classic Assert methods + if (typeName is "Assert" or "ClassicAssert") + { + // Handle Assert.That (constraint-based) separately + if (methodName == "That") + return true; + + if (NUnitClassicAssertMethods.Contains(methodName)) + return VerifyNUnitNamespace(invocation); + } + + // Check StringAssert, CollectionAssert, FileAssert, DirectoryAssert + if (typeName is "StringAssert" or "CollectionAssert" or "FileAssert" or "DirectoryAssert") + return VerifyNUnitNamespace(invocation); + } + + return false; + } + + private static string GetSimpleTypeName(ExpressionSyntax expression) + { + return expression switch + { + IdentifierNameSyntax identifier => identifier.Identifier.Text, + MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.Text, + _ => expression.ToString() + }; + } + + private bool VerifyNUnitNamespace(InvocationExpressionSyntax invocation) + { + try + { + var symbolInfo = SemanticModel.GetSymbolInfo(invocation); + if (symbolInfo.Symbol is IMethodSymbol methodSymbol) + { + var containingNamespace = methodSymbol.ContainingType?.ContainingNamespace?.ToDisplayString(); + return containingNamespace?.StartsWith("NUnit.Framework") == true; + } + } + catch + { + // Fall back to syntax-based detection + } + + return true; // Assume it's NUnit if syntax matches + } + + protected override AssertionConversion? AnalyzeAssertion(InvocationExpressionSyntax node) + { + if (node.Expression is not MemberAccessExpressionSyntax memberAccess) + return null; + + var typeName = GetSimpleTypeName(memberAccess.Expression); + var methodName = memberAccess.Name.Identifier.Text; + var arguments = node.ArgumentList.Arguments; + + // Handle constraint-based Assert.That + if (typeName == "Assert" && methodName == "That") + { + return ConvertAssertThat(node, arguments, memberAccess); + } + + // Handle classic assertions (Assert.* or ClassicAssert.*) + if (typeName is "Assert" or "ClassicAssert") + { + var (kind, replacementCode, introducesAwait, todoComment) = ConvertClassicAssert(methodName, arguments, memberAccess); + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment + }; + } + + // Handle StringAssert + if (typeName == "StringAssert") + { + var (kind, replacementCode, introducesAwait, todoComment) = ConvertStringAssert(methodName, arguments); + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment + }; + } + + // Handle CollectionAssert + if (typeName == "CollectionAssert") + { + var (kind, replacementCode, introducesAwait, todoComment) = ConvertCollectionAssert(methodName, arguments); + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment + }; + } + + // Handle FileAssert + if (typeName == "FileAssert") + { + var (kind, replacementCode, introducesAwait, todoComment) = ConvertFileAssert(methodName, arguments); + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment + }; + } + + // Handle DirectoryAssert + if (typeName == "DirectoryAssert") + { + var (kind, replacementCode, introducesAwait, todoComment) = ConvertDirectoryAssert(methodName, arguments); + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment + }; + } + + return null; + } + + private AssertionConversion? ConvertAssertThat( + InvocationExpressionSyntax node, + SeparatedSyntaxList args, + MemberAccessExpressionSyntax memberAccess) + { + if (args.Count < 2) + return null; + + var actualValue = args[0].Expression.ToString(); + var constraintArg = args[1].Expression; + + // Handle common constraint patterns + var (kind, assertionSuffix, todoComment) = AnalyzeConstraint(constraintArg); + + if (assertionSuffix == null) + { + // Complex constraint - leave a TODO + return new AssertionConversion + { + Kind = AssertionConversionKind.Unknown, + OriginalText = node.ToString(), + ReplacementCode = node.ToString(), + IntroducesAwait = false, + TodoComment = "// TODO: TUnit migration - Complex NUnit constraint. Manual conversion required." + }; + } + + // Check for message parameter + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actualValue}){assertionSuffix}.Because({message})" + : $"await Assert.That({actualValue}){assertionSuffix}"; + + return new AssertionConversion + { + Kind = kind, + OriginalText = node.ToString(), + ReplacementCode = assertion, + IntroducesAwait = true, + TodoComment = todoComment + }; + } + + private (AssertionConversionKind, string?, string?) AnalyzeConstraint(ExpressionSyntax constraint) + { + // Handle Is.EqualTo(expected) + if (constraint is InvocationExpressionSyntax invocation) + { + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess) + { + var receiverText = memberAccess.Expression.ToString(); + var methodName = memberAccess.Name.Identifier.Text; + var args = invocation.ArgumentList.Arguments; + + // Handle chained constraint modifiers like .Within(delta) on Is.EqualTo(x).Within(delta) + if (methodName == "Within" && memberAccess.Expression is InvocationExpressionSyntax innerConstraint) + { + // Get the base assertion (e.g., IsEqualTo(5)) first + var (kind, baseAssertion, todoComment) = AnalyzeConstraint(innerConstraint); + if (baseAssertion != null && args.Count >= 1) + { + var delta = args[0].Expression.ToString(); + return (kind, $"{baseAssertion}.Within({delta})", todoComment); + } + } + + // Is.EqualTo(expected) + if (receiverText == "Is" && methodName == "EqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.Equal, $".IsEqualTo({expected})", null); + } + + // Is.Not.EqualTo(expected) + if (receiverText == "Is.Not" && methodName == "EqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.NotEqual, $".IsNotEqualTo({expected})", null); + } + + // Is.SameAs(expected) + if (receiverText == "Is" && methodName == "SameAs" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.Same, $".IsSameReferenceAs({expected})", null); + } + + // Is.Not.SameAs(expected) + if (receiverText == "Is.Not" && methodName == "SameAs" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.NotSame, $".IsNotSameReferenceAs({expected})", null); + } + + // Is.InstanceOf() or Is.InstanceOf(type) + // InstanceOf checks if the value is an instance of T or a derived type, maps to IsAssignableTo + if (receiverText == "Is" && methodName == "InstanceOf") + { + if (memberAccess.Name is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count > 0) + { + var typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + return (AssertionConversionKind.IsAssignableFrom, $".IsAssignableTo<{typeArg}>()", null); + } + if (args.Count >= 1) + { + var typeArg = args[0].Expression.ToString(); + return (AssertionConversionKind.IsAssignableFrom, $".IsAssignableTo({typeArg})", null); + } + } + + // Is.TypeOf() + if (receiverText == "Is" && methodName == "TypeOf") + { + if (memberAccess.Name is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count > 0) + { + var typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + return (AssertionConversionKind.IsType, $".IsTypeOf<{typeArg}>()", null); + } + } + + // Is.AssignableTo() + if (receiverText == "Is" && methodName == "AssignableTo") + { + if (memberAccess.Name is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count > 0) + { + var typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + return (AssertionConversionKind.IsAssignableFrom, $".IsAssignableTo<{typeArg}>()", null); + } + } + + // Is.Not.InstanceOf() + if (receiverText == "Is.Not" && methodName == "InstanceOf") + { + if (memberAccess.Name is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count > 0) + { + var typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + return (AssertionConversionKind.IsNotType, $".IsNotAssignableTo<{typeArg}>()", null); + } + } + + // Is.Not.TypeOf() + if (receiverText == "Is.Not" && methodName == "TypeOf") + { + if (memberAccess.Name is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count > 0) + { + var typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + return (AssertionConversionKind.IsNotType, $".IsNotTypeOf<{typeArg}>()", null); + } + } + + // Does.StartWith(string) + if (receiverText == "Does" && methodName == "StartWith" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.StartsWith, $".StartsWith({expected})", null); + } + + // Does.EndWith(string) + if (receiverText == "Does" && methodName == "EndWith" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.EndsWith, $".EndsWith({expected})", null); + } + + // Does.Contain(string) + if (receiverText == "Does" && methodName == "Contain" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.Contains, $".Contains({expected})", null); + } + + // Does.Not.StartWith(string) + if (receiverText == "Does.Not" && methodName == "StartWith" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.StartsWith, $".DoesNotStartWith({expected})", null); + } + + // Does.Not.EndWith(string) + if (receiverText == "Does.Not" && methodName == "EndWith" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.EndsWith, $".DoesNotEndWith({expected})", null); + } + + // Does.Not.Contain(string) - for strings + if (receiverText == "Does.Not" && methodName == "Contain" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.DoesNotContain, $".DoesNotContain({expected})", null); + } + + // Does.Match(pattern) + if (receiverText == "Does" && methodName == "Match" && args.Count >= 1) + { + var pattern = args[0].Expression.ToString(); + return (AssertionConversionKind.Matches, $".Matches({pattern})", null); + } + + // Does.Not.Match(pattern) + if (receiverText == "Does.Not" && methodName == "Match" && args.Count >= 1) + { + var pattern = args[0].Expression.ToString(); + return (AssertionConversionKind.Matches, $".DoesNotMatch({pattern})", null); + } + + // Has.Count.EqualTo(n) + if (receiverText == "Has.Count" && methodName == "EqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.Equal, $".Count().IsEqualTo({expected})", null); + } + + // Has.Member(item) + if (receiverText == "Has" && methodName == "Member" && args.Count >= 1) + { + var item = args[0].Expression.ToString(); + return (AssertionConversionKind.Contains, $".Contains({item})", null); + } + + // Has.Exactly(n).Items + if (memberAccess.Expression is InvocationExpressionSyntax hasExactlyInvocation && + hasExactlyInvocation.Expression is MemberAccessExpressionSyntax hasExactlyAccess && + hasExactlyAccess.Expression.ToString() == "Has" && + hasExactlyAccess.Name.Identifier.Text == "Exactly" && + methodName == "Items" && + hasExactlyInvocation.ArgumentList.Arguments.Count >= 1) + { + var count = hasExactlyInvocation.ArgumentList.Arguments[0].Expression.ToString(); + return (AssertionConversionKind.Collection, $".Count().IsEqualTo({count})", null); + } + + // Is.GreaterThan(expected) + if (receiverText == "Is" && methodName == "GreaterThan" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsGreaterThan({expected})", null); + } + + // Is.LessThan(expected) + if (receiverText == "Is" && methodName == "LessThan" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsLessThan({expected})", null); + } + + // Is.GreaterThanOrEqualTo(expected) + if (receiverText == "Is" && methodName == "GreaterThanOrEqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsGreaterThanOrEqualTo({expected})", null); + } + + // Is.LessThanOrEqualTo(expected) + if (receiverText == "Is" && methodName == "LessThanOrEqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsLessThanOrEqualTo({expected})", null); + } + + // Is.Not.GreaterThan(expected) → IsLessThanOrEqualTo + if (receiverText == "Is.Not" && methodName == "GreaterThan" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsLessThanOrEqualTo({expected})", null); + } + + // Is.Not.LessThan(expected) → IsGreaterThanOrEqualTo + if (receiverText == "Is.Not" && methodName == "LessThan" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsGreaterThanOrEqualTo({expected})", null); + } + + // Is.Not.GreaterThanOrEqualTo(expected) → IsLessThan + if (receiverText == "Is.Not" && methodName == "GreaterThanOrEqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsLessThan({expected})", null); + } + + // Is.Not.LessThanOrEqualTo(expected) → IsGreaterThan + if (receiverText == "Is.Not" && methodName == "LessThanOrEqualTo" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.InRange, $".IsGreaterThan({expected})", null); + } + + // Contains.Item(expected) + if ((receiverText == "Contains" || receiverText == "Does.Contain") && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.Contains, $".Contains({expected})", null); + } + + // Does.Not.Contain / Does.Not.Contain(expected) + if (receiverText == "Does.Not" && methodName == "Contain" && args.Count >= 1) + { + var expected = args[0].Expression.ToString(); + return (AssertionConversionKind.DoesNotContain, $".DoesNotContain({expected})", null); + } + + // Throws.TypeOf() + if (receiverText == "Throws" && methodName == "TypeOf") + { + if (memberAccess.Name is GenericNameSyntax genericThrows && genericThrows.TypeArgumentList.Arguments.Count > 0) + { + var exceptionType = genericThrows.TypeArgumentList.Arguments[0].ToString(); + return (AssertionConversionKind.Throws, null, $"// TODO: TUnit migration - Convert to Assert.ThrowsAsync<{exceptionType}>()"); + } + } + } + } + + // Handle simple property constraints + if (constraint is MemberAccessExpressionSyntax simpleMemberAccess) + { + var fullConstraint = simpleMemberAccess.ToString(); + + // Is.True + if (fullConstraint is "Is.True") + return (AssertionConversionKind.True, ".IsTrue()", null); + + // Is.False + if (fullConstraint is "Is.False") + return (AssertionConversionKind.False, ".IsFalse()", null); + + // Is.Null + if (fullConstraint is "Is.Null") + return (AssertionConversionKind.Null, ".IsNull()", null); + + // Is.Not.Null + if (fullConstraint is "Is.Not.Null") + return (AssertionConversionKind.NotNull, ".IsNotNull()", null); + + // Is.Empty + if (fullConstraint is "Is.Empty") + return (AssertionConversionKind.Empty, ".IsEmpty()", null); + + // Is.Not.Empty + if (fullConstraint is "Is.Not.Empty") + return (AssertionConversionKind.NotEmpty, ".IsNotEmpty()", null); + + // Is.Positive + if (fullConstraint is "Is.Positive") + return (AssertionConversionKind.InRange, ".IsPositive()", null); + + // Is.Negative + if (fullConstraint is "Is.Negative") + return (AssertionConversionKind.InRange, ".IsNegative()", null); + + // Is.Zero + if (fullConstraint is "Is.Zero") + return (AssertionConversionKind.Equal, ".IsZero()", null); + + // Is.Not.Zero + if (fullConstraint is "Is.Not.Zero") + return (AssertionConversionKind.NotEqual, ".IsNotZero()", null); + + // Is.Not.Positive - means value <= 0 + if (fullConstraint is "Is.Not.Positive") + return (AssertionConversionKind.InRange, ".IsLessThanOrEqualTo(0)", null); + + // Is.Not.Negative - means value >= 0 + if (fullConstraint is "Is.Not.Negative") + return (AssertionConversionKind.InRange, ".IsGreaterThanOrEqualTo(0)", null); + + // Is.NaN + if (fullConstraint is "Is.NaN") + return (AssertionConversionKind.Equal, ".IsNaN()", null); + + // Is.Not.NaN + if (fullConstraint is "Is.Not.NaN") + return (AssertionConversionKind.NotEqual, ".IsNotNaN()", null); + + // Is.Ordered (default ascending) - use generic .IsInOrder() + if (fullConstraint is "Is.Ordered") + return (AssertionConversionKind.Collection, ".IsInOrder()", null); + + // Is.Ordered.Ascending - use generic .IsInOrder() + if (fullConstraint is "Is.Ordered.Ascending") + return (AssertionConversionKind.Collection, ".IsInOrder()", null); + + // Is.Ordered.Descending + if (fullConstraint is "Is.Ordered.Descending") + return (AssertionConversionKind.Collection, ".IsInDescendingOrder()", null); + + // Is.Unique + if (fullConstraint is "Is.Unique") + return (AssertionConversionKind.Collection, ".HasDistinctItems()", null); + + // Has.Exactly(n).Items - property access pattern + if (simpleMemberAccess.Name.Identifier.Text == "Items" && + simpleMemberAccess.Expression is InvocationExpressionSyntax hasExactlyInvocation && + hasExactlyInvocation.Expression is MemberAccessExpressionSyntax hasExactlyAccess && + hasExactlyAccess.Expression.ToString() == "Has" && + hasExactlyAccess.Name.Identifier.Text == "Exactly" && + hasExactlyInvocation.ArgumentList.Arguments.Count >= 1) + { + var count = hasExactlyInvocation.ArgumentList.Arguments[0].Expression.ToString(); + return (AssertionConversionKind.Collection, $".Count().IsEqualTo({count})", null); + } + } + + // Unknown constraint + return (AssertionConversionKind.Unknown, null, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertClassicAssert( + string methodName, SeparatedSyntaxList args, MemberAccessExpressionSyntax memberAccess) + { + return methodName switch + { + "AreEqual" => ConvertAreEqual(args), + "AreNotEqual" => ConvertAreNotEqual(args), + "AreSame" => ConvertAreSame(args), + "AreNotSame" => ConvertAreNotSame(args), + "IsTrue" => ConvertIsTrue(args), + "IsFalse" => ConvertIsFalse(args), + "IsNull" => ConvertIsNull(args), + "IsNotNull" => ConvertIsNotNull(args), + "IsEmpty" => ConvertIsEmpty(args), + "IsNotEmpty" => ConvertIsNotEmpty(args), + "Greater" => ConvertGreater(args), + "GreaterOrEqual" => ConvertGreaterOrEqual(args), + "Less" => ConvertLess(args), + "LessOrEqual" => ConvertLessOrEqual(args), + "Contains" => ConvertContains(args), + "Pass" => (AssertionConversionKind.Skip, "// Test passed", false, null), + "Fail" when args.Count > 0 => (AssertionConversionKind.Fail, $"Fail.Test({args[0].Expression})", false, null), + "Fail" => (AssertionConversionKind.Fail, "Fail.Test(\"\")", false, null), + "Inconclusive" when args.Count > 0 => (AssertionConversionKind.Inconclusive, $"Skip.Test({args[0].Expression})", false, null), + "Inconclusive" => (AssertionConversionKind.Inconclusive, "Skip.Test(\"Test inconclusive\")", false, null), + "Ignore" when args.Count > 0 => (AssertionConversionKind.Skip, $"Skip.Test({args[0].Expression})", false, null), + "Ignore" => (AssertionConversionKind.Skip, "Skip.Test(\"Ignored\")", false, null), + "Throws" or "ThrowsAsync" => ConvertThrows(memberAccess, args), + "Catch" or "CatchAsync" => ConvertCatch(memberAccess, args), + "DoesNotThrow" or "DoesNotThrowAsync" => ConvertDoesNotThrow(args), + "Positive" when args.Count >= 1 => (AssertionConversionKind.InRange, $"await Assert.That({args[0].Expression}).IsPositive()", true, null), + "Negative" when args.Count >= 1 => (AssertionConversionKind.InRange, $"await Assert.That({args[0].Expression}).IsNegative()", true, null), + "Zero" when args.Count >= 1 => (AssertionConversionKind.Equal, $"await Assert.That({args[0].Expression}).IsZero()", true, null), + "NotZero" when args.Count >= 1 => (AssertionConversionKind.NotEqual, $"await Assert.That({args[0].Expression}).IsNotZero()", true, null), + "Warn" when args.Count > 0 => (AssertionConversionKind.Skip, $"Skip.Test($\"Warning: {{{args[0].Expression}}}\")", false, null), + "Warn" => (AssertionConversionKind.Skip, "Skip.Test(\"Warning\")", false, null), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEqualTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsEqualTo({expected})"; + + return (AssertionConversionKind.AreEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreNotEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreNotEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotEqualTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotEqualTo({expected})"; + + return (AssertionConversionKind.AreNotEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreSame(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreSame, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsSameReferenceAs({expected}).Because({message})" + : $"await Assert.That({actual}).IsSameReferenceAs({expected})"; + + return (AssertionConversionKind.AreSame, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAreNotSame(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.AreNotSame, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotSameReferenceAs({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotSameReferenceAs({expected})"; + + return (AssertionConversionKind.AreNotSame, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsTrue(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsTrue, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsTrue().Because({message})" + : $"await Assert.That({value}).IsTrue()"; + + return (AssertionConversionKind.IsTrue, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsFalse(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsFalse, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsFalse().Because({message})" + : $"await Assert.That({value}).IsFalse()"; + + return (AssertionConversionKind.IsFalse, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsNull, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsNull().Because({message})" + : $"await Assert.That({value}).IsNull()"; + + return (AssertionConversionKind.IsNull, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNotNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsNotNull, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsNotNull().Because({message})" + : $"await Assert.That({value}).IsNotNull()"; + + return (AssertionConversionKind.IsNotNull, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsEmpty(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Empty, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsEmpty().Because({message})" + : $"await Assert.That({value}).IsEmpty()"; + + return (AssertionConversionKind.Empty, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNotEmpty(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.NotEmpty, null, false, null); + + var value = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({value}).IsNotEmpty().Because({message})" + : $"await Assert.That({value}).IsNotEmpty()"; + + return (AssertionConversionKind.NotEmpty, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertGreater(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.InRange, null, false, null); + + var actual = args[0].Expression.ToString(); + var expected = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsGreaterThan({expected}).Because({message})" + : $"await Assert.That({actual}).IsGreaterThan({expected})"; + + return (AssertionConversionKind.InRange, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertGreaterOrEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.InRange, null, false, null); + + var actual = args[0].Expression.ToString(); + var expected = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsGreaterThanOrEqualTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsGreaterThanOrEqualTo({expected})"; + + return (AssertionConversionKind.InRange, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertLess(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.InRange, null, false, null); + + var actual = args[0].Expression.ToString(); + var expected = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsLessThan({expected}).Because({message})" + : $"await Assert.That({actual}).IsLessThan({expected})"; + + return (AssertionConversionKind.InRange, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertLessOrEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.InRange, null, false, null); + + var actual = args[0].Expression.ToString(); + var expected = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsLessThanOrEqualTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsLessThanOrEqualTo({expected})"; + + return (AssertionConversionKind.InRange, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertContains(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); + + var expected = args[0].Expression.ToString(); + var collection = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).Contains({expected}).Because({message})" + : $"await Assert.That({collection}).Contains({expected})"; + + return (AssertionConversionKind.Contains, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrows( + MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Throws, null, false, null); + + string typeArg = "Exception"; + string action; + + // Check if generic: Assert.Throws(action) + if (memberAccess.Name is GenericNameSyntax genericName && + genericName.TypeArgumentList.Arguments.Count > 0) + { + typeArg = genericName.TypeArgumentList.Arguments[0].ToString(); + action = args[0].Expression.ToString(); + } + // Check if constraint form: Assert.Throws(Is.TypeOf(typeof(T)), action) + else if (args.Count >= 2) + { + var constraintExpr = args[0].Expression; + action = args[1].Expression.ToString(); + + // Try to extract type from constraint + typeArg = TryExtractTypeFromThrowsConstraint(constraintExpr) ?? "Exception"; + } + else + { + action = args[0].Expression.ToString(); + } + + var assertion = $"await Assert.ThrowsAsync<{typeArg}>({action})"; + return (AssertionConversionKind.Throws, assertion, true, null); + } + + private static string? TryExtractTypeFromThrowsConstraint(ExpressionSyntax constraint) + { + // Handle Is.TypeOf(typeof(ArgumentException)) + if (constraint is InvocationExpressionSyntax invocation) + { + var invocationText = invocation.ToString(); + if (invocationText.StartsWith("Is.TypeOf(typeof(")) + { + // Extract type from Is.TypeOf(typeof(ArgumentException)) + var start = "Is.TypeOf(typeof(".Length; + var end = invocationText.LastIndexOf("))"); + if (end > start) + { + return invocationText.Substring(start, end - start); + } + } + // Handle Is.TypeOf() + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Name is GenericNameSyntax genericName && + genericName.Identifier.Text == "TypeOf" && + genericName.TypeArgumentList.Arguments.Count > 0) + { + return genericName.TypeArgumentList.Arguments[0].ToString(); + } + } + // Handle Is.InstanceOf() + if (constraint is InvocationExpressionSyntax instanceOfInvocation && + instanceOfInvocation.Expression is MemberAccessExpressionSyntax instanceOfAccess && + instanceOfAccess.Name is GenericNameSyntax instanceOfGeneric && + instanceOfGeneric.Identifier.Text == "InstanceOf" && + instanceOfGeneric.TypeArgumentList.Arguments.Count > 0) + { + return instanceOfGeneric.TypeArgumentList.Arguments[0].ToString(); + } + return null; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCatch( + MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + // Catch is similar to Throws but returns the exception + return ConvertThrows(memberAccess, args); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDoesNotThrow(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Throws, null, false, null); + + var action = args[0].Expression.ToString(); + + // DoesNotThrow is simply invoking the action - if it throws, the test fails + var assertion = $"await Assert.That({action}).ThrowsNothing()"; + return (AssertionConversionKind.Throws, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringAssert( + string methodName, SeparatedSyntaxList args) + { + return methodName switch + { + "Contains" when args.Count >= 2 => ConvertStringContains(args), + "StartsWith" when args.Count >= 2 => ConvertStringStartsWith(args), + "EndsWith" when args.Count >= 2 => ConvertStringEndsWith(args), + "AreEqualIgnoringCase" when args.Count >= 2 => ConvertStringAreEqualIgnoringCase(args), + "IsMatch" when args.Count >= 2 => ConvertStringMatches(args), + "DoesNotMatch" when args.Count >= 2 => ConvertStringDoesNotMatch(args), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringContains(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).Contains({expected}).Because({message})" + : $"await Assert.That({actual}).Contains({expected})"; + + return (AssertionConversionKind.StringContains, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringStartsWith(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).StartsWith({expected}).Because({message})" + : $"await Assert.That({actual}).StartsWith({expected})"; + + return (AssertionConversionKind.StringStartsWith, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringEndsWith(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).EndsWith({expected}).Because({message})" + : $"await Assert.That({actual}).EndsWith({expected})"; + + return (AssertionConversionKind.StringEndsWith, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringAreEqualIgnoringCase(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEqualTo({expected}, StringComparison.OrdinalIgnoreCase).Because({message})" + : $"await Assert.That({actual}).IsEqualTo({expected}, StringComparison.OrdinalIgnoreCase)"; + + return (AssertionConversionKind.AreEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringMatches(SeparatedSyntaxList args) + { + var pattern = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).Matches({pattern}).Because({message})" + : $"await Assert.That({actual}).Matches({pattern})"; + + return (AssertionConversionKind.StringMatches, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStringDoesNotMatch(SeparatedSyntaxList args) + { + var pattern = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).DoesNotMatch({pattern}).Because({message})" + : $"await Assert.That({actual}).DoesNotMatch({pattern})"; + + return (AssertionConversionKind.StringDoesNotMatch, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAssert( + string methodName, SeparatedSyntaxList args) + { + return methodName switch + { + "AreEqual" when args.Count >= 2 => ConvertCollectionAreEqual(args), + "AreNotEqual" when args.Count >= 2 => ConvertCollectionAreNotEqual(args), + "AreEquivalent" when args.Count >= 2 => ConvertCollectionAreEquivalent(args), + "AreNotEquivalent" when args.Count >= 2 => ConvertCollectionAreNotEquivalent(args), + "Contains" when args.Count >= 2 => ConvertCollectionContains(args), + "DoesNotContain" when args.Count >= 2 => ConvertCollectionDoesNotContain(args), + "IsSubsetOf" when args.Count >= 2 => ConvertCollectionIsSubsetOf(args), + "IsNotSubsetOf" when args.Count >= 2 => ConvertCollectionIsNotSubsetOf(args), + "AllItemsAreUnique" when args.Count >= 1 => ConvertCollectionAllItemsAreUnique(args), + "AllItemsAreNotNull" when args.Count >= 1 => ConvertCollectionAllItemsAreNotNull(args), + "IsEmpty" when args.Count >= 1 => ConvertCollectionIsEmpty(args), + "IsNotEmpty" when args.Count >= 1 => ConvertCollectionIsNotEmpty(args), + _ => (AssertionConversionKind.Unknown, null, false, null) + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreEqual(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreNotEqual(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreNotEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreEquivalent(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreEquivalent, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAreNotEquivalent(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({actual}).IsNotEquivalentTo({expected}).Because({message})" + : $"await Assert.That({actual}).IsNotEquivalentTo({expected})"; + + return (AssertionConversionKind.CollectionAreNotEquivalent, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionContains(SeparatedSyntaxList args) + { + var collection = args[0].Expression.ToString(); + var element = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).Contains({element}).Because({message})" + : $"await Assert.That({collection}).Contains({element})"; + + return (AssertionConversionKind.CollectionContains, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionDoesNotContain(SeparatedSyntaxList args) + { + var collection = args[0].Expression.ToString(); + var element = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).DoesNotContain({element}).Because({message})" + : $"await Assert.That({collection}).DoesNotContain({element})"; + + return (AssertionConversionKind.CollectionDoesNotContain, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionIsSubsetOf(SeparatedSyntaxList args) + { + var subset = args[0].Expression.ToString(); + var superset = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({subset}).IsSubsetOf({superset}).Because({message})" + : $"await Assert.That({subset}).IsSubsetOf({superset})"; + + return (AssertionConversionKind.CollectionIsSubsetOf, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionIsNotSubsetOf(SeparatedSyntaxList args) + { + var subset = args[0].Expression.ToString(); + var superset = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That({subset}).IsNotSubsetOf({superset}).Because({message})" + : $"await Assert.That({subset}).IsNotSubsetOf({superset})"; + + return (AssertionConversionKind.CollectionIsNotSubsetOf, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAllItemsAreUnique(SeparatedSyntaxList args) + { + var collection = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).HasDistinctItems().Because({message})" + : $"await Assert.That({collection}).HasDistinctItems()"; + + return (AssertionConversionKind.CollectionAllItemsAreUnique, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionAllItemsAreNotNull(SeparatedSyntaxList args) + { + var collection = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).All(x => x != null).Because({message})" + : $"await Assert.That({collection}).All(x => x != null)"; + + return (AssertionConversionKind.CollectionAllItemsAreNotNull, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionIsEmpty(SeparatedSyntaxList args) + { + var collection = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).IsEmpty().Because({message})" + : $"await Assert.That({collection}).IsEmpty()"; + + return (AssertionConversionKind.Empty, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollectionIsNotEmpty(SeparatedSyntaxList args) + { + var collection = args[0].Expression.ToString(); + string? message = args.Count >= 2 ? GetMessageArgument(args[1]) : null; + + var assertion = message != null + ? $"await Assert.That({collection}).IsNotEmpty().Because({message})" + : $"await Assert.That({collection}).IsNotEmpty()"; + + return (AssertionConversionKind.NotEmpty, assertion, true, null); + } + + private static string? GetMessageArgument(ArgumentSyntax arg) + { + // NUnit message parameters can be string literals or expressions + if (arg.Expression is LiteralExpressionSyntax literal && + literal.IsKind(SyntaxKind.StringLiteralExpression)) + { + return arg.Expression.ToString(); + } + if (arg.Expression is InterpolatedStringExpressionSyntax) + { + return arg.Expression.ToString(); + } + // Return the expression as-is for other cases + return arg.Expression.ToString(); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertFileAssert( + string methodName, SeparatedSyntaxList args) + { + if (args.Count < 1) + return (AssertionConversionKind.Unknown, null, false, null); + + var path = args[0].Expression.ToString(); + + return methodName switch + { + "Exists" => (AssertionConversionKind.True, $"await Assert.That(File.Exists({path})).IsTrue()", true, null), + "DoesNotExist" => (AssertionConversionKind.True, $"await Assert.That(File.Exists({path})).IsFalse()", true, null), + "AreEqual" when args.Count >= 2 => ConvertFileAreEqual(args), + "AreNotEqual" when args.Count >= 2 => ConvertFileAreNotEqual(args), + _ => (AssertionConversionKind.Unknown, null, false, $"// TODO: TUnit migration - FileAssert.{methodName} has no direct equivalent") + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertFileAreEqual(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That(new FileInfo({actual})).HasSameContentAs(new FileInfo({expected})).Because({message})" + : $"await Assert.That(new FileInfo({actual})).HasSameContentAs(new FileInfo({expected}))"; + + return (AssertionConversionKind.Equal, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertFileAreNotEqual(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That(new FileInfo({actual})).DoesNotHaveSameContentAs(new FileInfo({expected})).Because({message})" + : $"await Assert.That(new FileInfo({actual})).DoesNotHaveSameContentAs(new FileInfo({expected}))"; + + return (AssertionConversionKind.NotEqual, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDirectoryAssert( + string methodName, SeparatedSyntaxList args) + { + if (args.Count < 1) + return (AssertionConversionKind.Unknown, null, false, null); + + var path = args[0].Expression.ToString(); + + return methodName switch + { + "Exists" => (AssertionConversionKind.True, $"await Assert.That(Directory.Exists({path})).IsTrue()", true, null), + "DoesNotExist" => (AssertionConversionKind.True, $"await Assert.That(Directory.Exists({path})).IsFalse()", true, null), + "AreEqual" when args.Count >= 2 => ConvertDirectoryAreEqual(args), + "AreNotEqual" when args.Count >= 2 => ConvertDirectoryAreNotEqual(args), + _ => (AssertionConversionKind.Unknown, null, false, $"// TODO: TUnit migration - DirectoryAssert.{methodName} has no direct equivalent") + }; + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDirectoryAreEqual(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That(new DirectoryInfo({actual})).IsEquivalentTo(new DirectoryInfo({expected})).Because({message})" + : $"await Assert.That(new DirectoryInfo({actual})).IsEquivalentTo(new DirectoryInfo({expected}))"; + + return (AssertionConversionKind.Equal, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDirectoryAreNotEqual(SeparatedSyntaxList args) + { + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + string? message = args.Count >= 3 ? GetMessageArgument(args[2]) : null; + + var assertion = message != null + ? $"await Assert.That(new DirectoryInfo({actual})).IsNotEquivalentTo(new DirectoryInfo({expected})).Because({message})" + : $"await Assert.That(new DirectoryInfo({actual})).IsNotEquivalentTo(new DirectoryInfo({expected}))"; + + return (AssertionConversionKind.NotEqual, assertion, true, null); + } + + protected override bool ShouldRemoveAttribute(AttributeSyntax node) + { + var name = MigrationHelpers.GetAttributeName(node); + + if (NUnitRemovableAttributeNames.Contains(name)) + return true; + + // Handle conditionally removable attributes + if (name == "Parallelizable") + { + // Remove Parallelizable unless it has ParallelScope.None (which converts to NotInParallel) + if (node.ArgumentList?.Arguments.Count > 0) + { + var arg = node.ArgumentList.Arguments[0].Expression.ToString(); + if (arg.Contains("None")) + return false; // Don't remove - this will be converted to NotInParallel + } + return true; // Remove - TUnit is parallel by default + } + + return false; + } + + protected override AttributeConversion? AnalyzeAttribute(AttributeSyntax node) + { + var name = MigrationHelpers.GetAttributeName(node); + + if (!NUnitAttributeNames.Contains(name)) + return null; + + // Handle TestCase specially due to complex property conversions + if (name == "TestCase") + { + return ConvertTestCaseAttribute(node); + } + + // Handle Test specially if it has properties + if (name == "Test") + { + return ConvertTestAttributeFull(node); + } + + var (newName, newArgs) = name switch + { + "Theory" => ("Test", null), + "TestCaseSource" => ("MethodDataSource", ConvertTestCaseSourceArgs(node)), + "SetUp" => ("Before", "(HookType.Test)"), + "TearDown" => ("After", "(HookType.Test)"), + "OneTimeSetUp" => ("Before", "(HookType.Class)"), + "OneTimeTearDown" => ("After", "(HookType.Class)"), + "Category" => ("Category", node.ArgumentList?.ToString()), + "Ignore" => ("Skip", node.ArgumentList?.ToString()), + "Explicit" => ("Explicit", node.ArgumentList?.ToString()), + "Description" => ("Property", ConvertDescriptionArgs(node)), + "Author" => ("Property", ConvertAuthorArgs(node)), + "Repeat" => ("Repeat", node.ArgumentList?.ToString()), + "Values" => ("Matrix", node.ArgumentList?.ToString()), + "ValueSource" => ("MatrixSourceMethod", node.ArgumentList?.ToString()), + "NonParallelizable" => ("NotInParallel", null), + "Parallelizable" => ConvertParallelizableAttribute(node), + "Platform" => ConvertPlatformAttribute(node), + "Apartment" => ConvertApartmentAttribute(node), + "ExpectedException" => (null, null), // Handled separately + "Sequential" => (null, null), // No direct equivalent - TODO needed + _ => (null, null) + }; + + if (newName == null) + return null; + + return new AttributeConversion + { + NewAttributeName = newName, + NewArgumentList = newArgs, + OriginalText = node.ToString() + }; + } + + private static string? ConvertTestCaseSourceArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var firstArg = node.ArgumentList.Arguments[0]; + return $"({firstArg.Expression})"; + } + return null; + } + + private static string? ConvertDescriptionArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var value = node.ArgumentList.Arguments[0].Expression.ToString(); + return $"(\"Description\", {value})"; + } + return null; + } + + private static string? ConvertAuthorArgs(AttributeSyntax node) + { + if (node.ArgumentList?.Arguments.Count > 0) + { + var value = node.ArgumentList.Arguments[0].Expression.ToString(); + return $"(\"Author\", {value})"; + } + return null; + } + + private AttributeConversion ConvertTestAttributeFull(AttributeSyntax node) + { + // Handle [Test(Description = "...", Author = "...")] + if (node.ArgumentList?.Arguments.Count == 0 || node.ArgumentList == null) + { + return new AttributeConversion + { + NewAttributeName = "Test", + NewArgumentList = null, + OriginalText = node.ToString() + }; + } + + // Check for Description, Author properties + var additionalAttributes = new List(); + foreach (var arg in node.ArgumentList.Arguments) + { + if (arg.NameEquals != null) + { + var propName = arg.NameEquals.Name.Identifier.Text; + var propValue = arg.Expression.ToString(); + + switch (propName) + { + case "Description": + additionalAttributes.Add(new AdditionalAttribute + { + Name = "Property", + Arguments = $"(\"Description\", {propValue})" + }); + break; + case "Author": + additionalAttributes.Add(new AdditionalAttribute + { + Name = "Property", + Arguments = $"(\"Author\", {propValue})" + }); + break; + case "ExpectedResult": + // ExpectedResult is handled separately by the ExpectedResult rewriter + break; + } + } + } + + // Test attribute has no positional arguments in TUnit + // Use empty string to explicitly remove arguments (null would keep them) + return new AttributeConversion + { + NewAttributeName = "Test", + NewArgumentList = "", + AdditionalAttributes = additionalAttributes.Count > 0 ? additionalAttributes : null, + OriginalText = node.ToString() + }; + } + + private AttributeConversion? ConvertTestCaseAttribute(AttributeSyntax node) + { + // [TestCase(1, TestName = "Test")] -> [Arguments(1, DisplayName = "Test")] + // [TestCase(1, Category = "Unit")] -> [Arguments(1, Categories = ["Unit"])] + // [TestCase(1, Ignore = "reason")] -> [Arguments(1, Skip = "reason")] + // [TestCase(1, IgnoreReason = "reason")] -> [Arguments(1, Skip = "reason")] + // [TestCase(1, Description = "...")] -> [Arguments(1)] + [Property("Description", "...")] + // [TestCase(1, Author = "...")] -> [Arguments(1)] + [Property("Author", "...")] + // [TestCase(1, Explicit = true)] -> [Arguments(1)] + [Explicit] + // [TestCase(1, ExplicitReason = "...")] -> [Arguments(1)] + [Explicit] + [Property("ExplicitReason", "...")] + + if (node.ArgumentList == null) + { + return new AttributeConversion + { + NewAttributeName = "Arguments", + NewArgumentList = null, + OriginalText = node.ToString() + }; + } + + var positionalArgs = new List(); + var inlineNamedArgs = new List(); + var additionalAttributes = new List(); + var hasExpectedResult = false; + + foreach (var arg in node.ArgumentList.Arguments) + { + if (arg.NameEquals != null || arg.NameColon != null) + { + // Named argument + var propName = arg.NameEquals?.Name.Identifier.Text + ?? arg.NameColon?.Name.Identifier.Text + ?? ""; + var propValue = arg.Expression.ToString(); + + switch (propName) + { + case "TestName": + // Convert to DisplayName (inline) + inlineNamedArgs.Add($"DisplayName = {propValue}"); + break; + case "Category": + // Convert to Categories array (inline) + inlineNamedArgs.Add($"Categories = [{propValue}]"); + break; + case "Ignore": + case "IgnoreReason": + // Convert to Skip (inline) + inlineNamedArgs.Add($"Skip = {propValue}"); + break; + case "Description": + // Convert to separate [Property] attribute + additionalAttributes.Add(new AdditionalAttribute + { + Name = "Property", + Arguments = $"(\"Description\", {propValue})" + }); + break; + case "Author": + // Convert to separate [Property] attribute + additionalAttributes.Add(new AdditionalAttribute + { + Name = "Property", + Arguments = $"(\"Author\", {propValue})" + }); + break; + case "Explicit": + // Convert to separate [Explicit] attribute (only if true) + if (propValue == "true") + { + additionalAttributes.Add(new AdditionalAttribute { Name = "Explicit" }); + } + break; + case "ExplicitReason": + // Convert to [Explicit] + [Property] + additionalAttributes.Add(new AdditionalAttribute { Name = "Explicit" }); + additionalAttributes.Add(new AdditionalAttribute + { + Name = "Property", + Arguments = $"(\"ExplicitReason\", {propValue})" + }); + break; + case "ExpectedResult": + // ExpectedResult is handled separately by the ExpectedResult rewriter + hasExpectedResult = true; + break; + } + } + else + { + // Positional argument - keep as-is + positionalArgs.Add(arg.ToString()); + } + } + + // If this TestCase has ExpectedResult, don't convert it here. + // Let the NUnitExpectedResultRewriter handle the complete transformation. + if (hasExpectedResult) + { + return null; + } + + // Build the new argument list + var allArgs = new List(positionalArgs); + allArgs.AddRange(inlineNamedArgs); + var newArgList = allArgs.Count > 0 ? $"({string.Join(", ", allArgs)})" : null; + + return new AttributeConversion + { + NewAttributeName = "Arguments", + NewArgumentList = newArgList, + AdditionalAttributes = additionalAttributes.Count > 0 ? additionalAttributes : null, + OriginalText = node.ToString() + }; + } + + private (string?, string?) ConvertParallelizableAttribute(AttributeSyntax node) + { + // [Parallelizable(ParallelScope.None)] -> [NotInParallel] + // [Parallelizable(ParallelScope.Self)] -> Keep as default (TUnit default is parallel) + // [Parallelizable] with no args -> Keep as default + if (node.ArgumentList?.Arguments.Count > 0) + { + var arg = node.ArgumentList.Arguments[0].Expression.ToString(); + if (arg.Contains("None")) + { + return ("NotInParallel", ""); // Empty string = no arguments + } + } + // Default parallelizable - TUnit is parallel by default, so we can remove this + return (null, null); + } + + private (string?, string?) ConvertPlatformAttribute(AttributeSyntax node) + { + // [Platform(Include = "Win")] -> [RunOn(OS.Windows)] + // [Platform(Exclude = "Linux")] -> [ExcludeOn(OS.Linux)] + if (node.ArgumentList?.Arguments.Count > 0) + { + foreach (var arg in node.ArgumentList.Arguments) + { + var nameText = arg.NameEquals?.Name.Identifier.Text; + var valueText = arg.Expression.ToString().Trim('"'); + + if (nameText == "Include") + { + var os = MapPlatformToOS(valueText); + return ("RunOn", $"({os})"); + } + if (nameText == "Exclude") + { + var os = MapPlatformToOS(valueText); + return ("ExcludeOn", $"({os})"); + } + } + } + return (null, null); + } + + private static string MapPlatformToOS(string platform) + { + // Handle multiple platforms separated by comma: "Win,Linux" -> "OS.Windows | OS.Linux" + if (platform.Contains(",")) + { + var platforms = platform.Split(',') + .Select(p => MapSinglePlatformToOS(p.Trim())) + .ToArray(); + return string.Join(" | ", platforms); + } + + return MapSinglePlatformToOS(platform); + } + + private static string MapSinglePlatformToOS(string platform) + { + return platform.ToLowerInvariant() switch + { + "win" or "windows" or "win32" or "win64" => "OS.Windows", + "linux" or "unix" => "OS.Linux", + "macos" or "osx" or "macosx" => "OS.MacOS", + _ => $"OS.{platform}" + }; + } + + private (string?, string?) ConvertApartmentAttribute(AttributeSyntax node) + { + // [Apartment(ApartmentState.STA)] -> [STAThreadExecutor] + if (node.ArgumentList?.Arguments.Count > 0) + { + var arg = node.ArgumentList.Arguments[0].Expression.ToString(); + if (arg.Contains("STA")) + { + return ("STAThreadExecutor", ""); // Empty string to remove arguments + } + } + return (null, null); + } + + protected override ParameterAttributeConversion? AnalyzeParameterAttribute(AttributeSyntax attr, ParameterSyntax parameter) + { + var attrName = MigrationHelpers.GetAttributeName(attr); + + if (attrName == "Range") + { + return ConvertRangeAttribute(attr, parameter); + } + + return null; + } + + protected override CompilationUnitSyntax AnalyzeMethodsForMissingAttributes(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Find methods that have TestCase/Arguments attributes but no Test attribute + foreach (var method in root.DescendantNodes().OfType()) + { + try + { + var hasTestCaseOrArguments = false; + var hasTestAttribute = false; + + foreach (var attrList in method.AttributeLists) + { + foreach (var attr in attrList.Attributes) + { + var attrName = MigrationHelpers.GetAttributeName(attr); + if (attrName == "TestCase" || attrName == "Arguments" || attrName == "TestCaseSource") + { + hasTestCaseOrArguments = true; + } + else if (attrName == "Test" || attrName == "Theory") + { + hasTestAttribute = true; + } + } + } + + if (hasTestCaseOrArguments && !hasTestAttribute) + { + // Need to add [Test] attribute + var addition = new MethodAttributeAddition + { + AttributeCode = "Test", + OriginalText = method.Identifier.Text + }; + Plan.MethodAttributeAdditions.Add(addition); + + // Annotate the method so we can find it during transformation + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Span == method.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(addition.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "MethodMissingAttributeAnalysis", + Description = ex.Message, + OriginalCode = method.Identifier.Text, + Exception = ex + }); + } + } + + return currentRoot; + } + + protected override CompilationUnitSyntax AnalyzeMethodSignatures(CompilationUnitSyntax root) + { + // First call base to handle async conversions + var currentRoot = base.AnalyzeMethodSignatures(root); + + // Then handle lifecycle method visibility - find methods with SetUp/TearDown etc. + // that are not public and need to be made public + var lifecycleAttributeNames = new HashSet { "SetUp", "TearDown", "OneTimeSetUp", "OneTimeTearDown" }; + var processedMethodSpans = new HashSet(); + + foreach (var method in OriginalRoot.DescendantNodes().OfType()) + { + // Check if this method has a lifecycle attribute + var hasLifecycleAttribute = method.AttributeLists + .SelectMany(al => al.Attributes) + .Any(attr => lifecycleAttributeNames.Contains(MigrationHelpers.GetAttributeName(attr))); + + if (!hasLifecycleAttribute) continue; + + // Check if already public + if (method.Modifiers.Any(SyntaxKind.PublicKeyword)) continue; + + // Check if we already processed this method + if (processedMethodSpans.Contains(method.Span)) continue; + processedMethodSpans.Add(method.Span); + + try + { + // Find the method in the current root + var currentMethod = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Span == method.Span || m.Identifier.Text == method.Identifier.Text); + + if (currentMethod == null) continue; + + // Check if there's already a MethodSignatureChange for this method + var existingChange = Plan.MethodSignatureChanges + .FirstOrDefault(c => currentMethod.HasAnnotation(c.Annotation)); + + if (existingChange != null) + { + // Update the existing change to also make public + // Since we can't modify the existing record, we need to remove and re-add + // Actually, this is tricky. Let's just add a new change if none exists. + // If there's an existing change, the transformer will handle both. + continue; + } + + var change = new MethodSignatureChange + { + MakePublic = true, + OriginalText = $"visibility:{method.Identifier.Text}" + }; + Plan.MethodSignatureChanges.Add(change); + + var annotatedMethod = currentMethod.WithAdditionalAnnotations(change.Annotation); + currentRoot = currentRoot.ReplaceNode(currentMethod, annotatedMethod); + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "LifecycleMethodVisibilityAnalysis", + Description = ex.Message, + OriginalCode = method.Identifier.Text, + Exception = ex + }); + } + } + + return currentRoot; + } + + private ParameterAttributeConversion? ConvertRangeAttribute(AttributeSyntax attr, ParameterSyntax parameter) + { + // [Range(1, 5)] -> [MatrixRange(1, 5)] + // [Range(1.0, 5.0)] -> [MatrixRange(1.0, 5.0)] + // [Range(1L, 100L)] -> [MatrixRange(1L, 100L)] + // [Range(1.0f, 5.0f)] -> [MatrixRange(1.0f, 5.0f)] + + if (attr.ArgumentList?.Arguments.Count < 2) + return null; + + // Determine the type from the first argument literal or the parameter type + var firstArg = attr.ArgumentList.Arguments[0].Expression.ToString(); + var typeArg = InferRangeType(firstArg, parameter); + + return new ParameterAttributeConversion + { + NewAttributeName = $"MatrixRange<{typeArg}>", + NewArgumentList = null, // Keep original arguments + OriginalText = attr.ToString() + }; + } + + private static string InferRangeType(string literal, ParameterSyntax parameter) + { + // Check for literal suffix + if (literal.EndsWith("L", StringComparison.OrdinalIgnoreCase)) + return "long"; + if (literal.EndsWith("f", StringComparison.OrdinalIgnoreCase)) + return "float"; + if (literal.EndsWith("d", StringComparison.OrdinalIgnoreCase) || literal.Contains(".")) + return "double"; + + // Fall back to parameter type if available + var paramType = parameter.Type?.ToString(); + if (!string.IsNullOrEmpty(paramType)) + { + return paramType switch + { + "long" => "long", + "float" => "float", + "double" => "double", + "decimal" => "decimal", + "short" => "short", + "byte" => "byte", + _ => "int" + }; + } + + return "int"; + } + + protected override bool ShouldRemoveBaseType(BaseTypeSyntax baseType) + { + // NUnit doesn't have common base types to remove like xUnit's IClassFixture + return false; + } + + protected override void AnalyzeUsings() + { + Plan.UsingPrefixesToRemove.Add("NUnit"); + // TUnit usings are handled automatically by MigrationHelpers + } +} diff --git a/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs b/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs new file mode 100644 index 0000000000..04d4b85592 --- /dev/null +++ b/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs @@ -0,0 +1,1772 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; + +namespace TUnit.Analyzers.CodeFixers.TwoPhase; + +/// +/// Phase 1 analyzer for xUnit to TUnit migration. +/// Collects all conversion targets while the semantic model is valid. +/// +public class XUnitTwoPhaseAnalyzer : MigrationAnalyzer +{ + private static readonly HashSet XUnitAssertMethods = new() + { + "Equal", "NotEqual", "Same", "NotSame", "StrictEqual", + "True", "False", + "Null", "NotNull", + "Empty", "NotEmpty", "Single", "Contains", "DoesNotContain", "All", + "Throws", "ThrowsAsync", "ThrowsAny", "ThrowsAnyAsync", + "IsType", "IsNotType", "IsAssignableFrom", + "StartsWith", "EndsWith", "Matches", "DoesNotMatch", + "InRange", "NotInRange", + "Fail", "Skip", "Collection", + "PropertyChanged", "PropertyChangedAsync", + "Raises", "RaisesAsync", "RaisesAny", "RaisesAnyAsync", + "Subset", "Superset", "ProperSubset", "ProperSuperset", + "Distinct", "Equivalent" + }; + + // Track assertions that are assigned to variables (should not be converted) + private readonly HashSet _variableAssignedAssertions = new(); + + private static readonly HashSet XUnitAttributeNames = new() + { + "Fact", "Theory", "InlineData", "MemberData", "ClassData", + "Trait", "Collection", "CollectionDefinition" + }; + + private static readonly HashSet XUnitRemovableAttributeNames = new() + { + // No longer removing Collection - we handle it specially now + }; + + private static readonly HashSet XUnitBaseTypes = new() + { + "IClassFixture", "ICollectionFixture", "IAsyncLifetime" + }; + + public XUnitTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + : base(semanticModel, compilation) + { + } + + protected override IEnumerable FindAssertionNodes(CompilationUnitSyntax root) + { + // First pass: detect assertions that are assigned to variables + foreach (var invocation in root.DescendantNodes().OfType()) + { + if (IsXUnitAssertion(invocation) && IsAssignedToVariable(invocation)) + { + _variableAssignedAssertions.Add(invocation); + } + } + + // Return all xUnit assertions (including variable-assigned ones, handled specially in AnalyzeAssertion) + return root.DescendantNodes() + .OfType() + .Where(IsXUnitAssertion); + } + + private static bool IsAssignedToVariable(InvocationExpressionSyntax invocation) + { + // Check if this invocation is the right-hand side of a variable declaration + // e.g., var ex = Assert.Throws(...) + var parent = invocation.Parent; + + // Check for variable declarator: var ex = Assert.Throws(...) + if (parent is EqualsValueClauseSyntax equalsClause && + equalsClause.Parent is VariableDeclaratorSyntax) + { + return true; + } + + // Check for assignment expression: ex = Assert.Throws(...) + if (parent is AssignmentExpressionSyntax) + { + return true; + } + + return false; + } + + private bool IsXUnitAssertion(InvocationExpressionSyntax invocation) + { + // Check for Assert.X pattern + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess) + { + // Syntax-based check first (fast) + if (memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" }) + { + var methodName = memberAccess.Name.Identifier.Text; + if (XUnitAssertMethods.Contains(methodName)) + { + // Semantic check to confirm it's xUnit + var symbolInfo = SemanticModel.GetSymbolInfo(invocation); + if (symbolInfo.Symbol is IMethodSymbol methodSymbol) + { + var containingType = methodSymbol.ContainingType?.ToDisplayString(); + return containingType?.StartsWith("Xunit.Assert") == true; + } + } + } + } + + return false; + } + + protected override AssertionConversion? AnalyzeAssertion(InvocationExpressionSyntax node) + { + if (node.Expression is not MemberAccessExpressionSyntax memberAccess) + return null; + + var methodName = memberAccess.Name.Identifier.Text; + var arguments = node.ArgumentList.Arguments; + + // Skip Throws/ThrowsAsync if assigned to a variable (result is used) + if ((methodName is "Throws" or "ThrowsAsync" or "ThrowsAny" or "ThrowsAnyAsync") && + _variableAssignedAssertions.Contains(node)) + { + return null; // Don't convert - keep as-is + } + + var (kind, replacementCode, introducesAwait, todoComment) = methodName switch + { + "Equal" => ConvertEqual(arguments), + "NotEqual" => ConvertNotEqual(arguments), + "True" => ConvertTrue(arguments), + "False" => ConvertFalse(arguments), + "Null" => ConvertNull(arguments), + "NotNull" => ConvertNotNull(arguments), + "Same" => ConvertSame(arguments), + "NotSame" => ConvertNotSame(arguments), + "StrictEqual" => ConvertStrictEqual(arguments), + "Empty" => ConvertEmpty(arguments), + "NotEmpty" => ConvertNotEmpty(arguments), + "Single" => ConvertSingle(arguments), + "Contains" => ConvertContains(arguments), + "DoesNotContain" => ConvertDoesNotContain(arguments), + "Throws" => ConvertThrows(memberAccess, arguments), + "ThrowsAsync" => ConvertThrowsAsync(memberAccess, arguments), + "ThrowsAny" => ConvertThrowsAny(memberAccess, arguments), + "ThrowsAnyAsync" => ConvertThrowsAnyAsync(memberAccess, arguments), + "IsType" => ConvertIsType(memberAccess, arguments), + "IsNotType" => ConvertIsNotType(memberAccess, arguments), + "IsAssignableFrom" => ConvertIsAssignableFrom(memberAccess, arguments), + "StartsWith" => ConvertStartsWith(arguments), + "EndsWith" => ConvertEndsWith(arguments), + "Matches" => ConvertMatches(arguments), + "DoesNotMatch" => ConvertDoesNotMatch(arguments), + "InRange" => ConvertInRange(arguments), + "NotInRange" => ConvertNotInRange(arguments), + "Fail" => (AssertionConversionKind.Fail, "Assert.Fail()", false, (string?)null), + "All" => ConvertAll(arguments), + "Collection" => ConvertCollection(arguments), + "Subset" => ConvertSubset(arguments), + "Superset" => ConvertSuperset(arguments), + "ProperSubset" => ConvertProperSubset(arguments), + "ProperSuperset" => ConvertProperSuperset(arguments), + "Distinct" => ConvertDistinct(arguments), + "Equivalent" => ConvertEquivalent(arguments), + _ => (AssertionConversionKind.Unknown, null, false, (string?)null) + }; + + if (replacementCode == null) + return null; + + return new AssertionConversion + { + Kind = kind, + ReplacementCode = replacementCode, + IntroducesAwait = introducesAwait, + TodoComment = todoComment, + OriginalText = node.ToString() + }; + } + + #region Assertion Conversions + + private (AssertionConversionKind, string?, bool, string?) ConvertEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Equal, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.Equal, $"await Assert.That({actual}).IsEqualTo({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertNotEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.NotEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.NotEqual, $"await Assert.That({actual}).IsNotEqualTo({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertTrue(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.True, null, false, null); + + var condition = args[0].Expression.ToString(); + var message = GetMessageArgument(args, 1); + var assertion = $"await Assert.That({condition}).IsTrue()"; + if (message != null) + { + assertion += $".Because({message})"; + } + return (AssertionConversionKind.True, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertFalse(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.False, null, false, null); + + var condition = args[0].Expression.ToString(); + var message = GetMessageArgument(args, 1); + var assertion = $"await Assert.That({condition}).IsFalse()"; + if (message != null) + { + assertion += $".Because({message})"; + } + return (AssertionConversionKind.False, assertion, true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Null, null, false, null); + + var value = args[0].Expression.ToString(); + return (AssertionConversionKind.Null, $"await Assert.That({value}).IsNull()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertNotNull(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.NotNull, null, false, null); + + var value = args[0].Expression.ToString(); + return (AssertionConversionKind.NotNull, $"await Assert.That({value}).IsNotNull()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertSame(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Same, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.Same, $"await Assert.That({actual}).IsSameReferenceAs({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertNotSame(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.NotSame, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.NotSame, $"await Assert.That({actual}).IsNotSameReferenceAs({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStrictEqual(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StrictEqual, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.StrictEqual, $"await Assert.That({actual}).IsStrictlyEqualTo({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertEmpty(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Empty, null, false, null); + + var collection = args[0].Expression.ToString(); + return (AssertionConversionKind.Empty, $"await Assert.That({collection}).IsEmpty()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertNotEmpty(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.NotEmpty, null, false, null); + + var collection = args[0].Expression.ToString(); + return (AssertionConversionKind.NotEmpty, $"await Assert.That({collection}).IsNotEmpty()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertSingle(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Single, null, false, null); + + var collection = args[0].Expression.ToString(); + return (AssertionConversionKind.Single, $"await Assert.That({collection}).HasSingleItem()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertContains(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); + + var expected = args[0].Expression.ToString(); + var collection = args[1].Expression.ToString(); + + return (AssertionConversionKind.Contains, $"await Assert.That({collection}).Contains({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDoesNotContain(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.DoesNotContain, null, false, null); + + var expected = args[0].Expression.ToString(); + var collection = args[1].Expression.ToString(); + + return (AssertionConversionKind.DoesNotContain, $"await Assert.That({collection}).DoesNotContain({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrows(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + // TUnit has Assert.Throws(action) with the same API as xUnit - no conversion needed! + // TUnit's Assert.Throws returns TException just like xUnit. + return (AssertionConversionKind.Throws, null, false, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrowsAsync(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + // TUnit has Assert.ThrowsAsync(action) with similar API to xUnit - no conversion needed! + // TUnit's Assert.ThrowsAsync returns ThrowsAssertion which is awaitable like xUnit's Task. + return (AssertionConversionKind.ThrowsAsync, null, false, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrowsAny(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.ThrowsAny, null, false, null); + + var typeArg = GetGenericTypeArgument(memberAccess); + var action = args[0].Expression.ToString(); + + if (typeArg != null) + { + return (AssertionConversionKind.ThrowsAny, $"await Assert.That({action}).Throws<{typeArg}>()", true, null); + } + + return (AssertionConversionKind.ThrowsAny, $"await Assert.That({action}).ThrowsException()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertThrowsAnyAsync(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.ThrowsAnyAsync, null, false, null); + + var typeArg = GetGenericTypeArgument(memberAccess); + var action = args[0].Expression.ToString(); + + if (typeArg != null) + { + return (AssertionConversionKind.ThrowsAnyAsync, $"await Assert.That({action}).Throws<{typeArg}>()", true, null); + } + + return (AssertionConversionKind.ThrowsAnyAsync, $"await Assert.That({action}).ThrowsException()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsType(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsType, null, false, null); + + var typeArg = GetGenericTypeArgument(memberAccess); + var value = args[0].Expression.ToString(); + + if (typeArg != null) + { + return (AssertionConversionKind.IsType, $"await Assert.That({value}).IsTypeOf<{typeArg}>()", true, null); + } + + return (AssertionConversionKind.IsType, null, false, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsNotType(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsNotType, null, false, null); + + var typeArg = GetGenericTypeArgument(memberAccess); + var value = args[0].Expression.ToString(); + + if (typeArg != null) + { + return (AssertionConversionKind.IsNotType, $"await Assert.That({value}).IsNotTypeOf<{typeArg}>()", true, null); + } + + return (AssertionConversionKind.IsNotType, null, false, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertIsAssignableFrom(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.IsAssignableFrom, null, false, null); + + var typeArg = GetGenericTypeArgument(memberAccess); + var value = args[0].Expression.ToString(); + + if (typeArg != null) + { + return (AssertionConversionKind.IsAssignableFrom, $"await Assert.That({value}).IsAssignableTo<{typeArg}>()", true, null); + } + + return (AssertionConversionKind.IsAssignableFrom, null, false, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertStartsWith(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.StartsWith, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.StartsWith, $"await Assert.That({actual}).StartsWith({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertEndsWith(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.EndsWith, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.EndsWith, $"await Assert.That({actual}).EndsWith({expected})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertMatches(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Matches, null, false, null); + + var pattern = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.Matches, $"await Assert.That({actual}).Matches({pattern})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertInRange(SeparatedSyntaxList args) + { + if (args.Count < 3) return (AssertionConversionKind.InRange, null, false, null); + + var actual = args[0].Expression.ToString(); + var low = args[1].Expression.ToString(); + var high = args[2].Expression.ToString(); + + return (AssertionConversionKind.InRange, $"await Assert.That({actual}).IsInRange({low},{high})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertNotInRange(SeparatedSyntaxList args) + { + if (args.Count < 3) return (AssertionConversionKind.NotInRange, null, false, null); + + var actual = args[0].Expression.ToString(); + var low = args[1].Expression.ToString(); + var high = args[2].Expression.ToString(); + + return (AssertionConversionKind.NotInRange, $"await Assert.That({actual}).IsNotInRange({low},{high})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertAll(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.All, null, false, null); + + var collection = args[0].Expression.ToString(); + var actionExpression = args[1].Expression; + + // Try to extract predicate from Assert.True/False patterns + var predicate = TryConvertActionToPredicate(actionExpression); + + return (AssertionConversionKind.All, $"await Assert.That({collection}).All({predicate})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertCollection(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Collection, null, false, null); + + var collection = args[0].Expression.ToString(); + // Count the element inspectors (args after the first one) + var inspectorCount = args.Count - 1; + + var todoComment = "// TODO: TUnit migration - Assert.Collection had element inspectors. Manually add assertions for each element."; + return (AssertionConversionKind.Collection, $"await Assert.That({collection}).HasCount({inspectorCount})", true, todoComment); + } + + private string TryConvertActionToPredicate(ExpressionSyntax actionExpression) + { + // Try to convert xUnit action patterns to TUnit predicates + // Pattern: item => Assert.True(item > 0) -> item => item > 0 + // Pattern: item => Assert.False(item < 0) -> item => !(item < 0) + + if (actionExpression is SimpleLambdaExpressionSyntax simpleLambda) + { + var parameter = simpleLambda.Parameter.Identifier.Text; + var body = simpleLambda.Body; + + // Check if body is an xUnit assertion invocation + if (body is InvocationExpressionSyntax invocation && + invocation.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" }) + { + var methodName = memberAccess.Name.Identifier.Text; + var invocationArgs = invocation.ArgumentList.Arguments; + + var predicateBody = methodName switch + { + "True" when invocationArgs.Count >= 1 => invocationArgs[0].Expression.ToString(), + "False" when invocationArgs.Count >= 1 => $"!({invocationArgs[0].Expression})", + "NotNull" when invocationArgs.Count >= 1 => $"{invocationArgs[0].Expression} != null", + "Null" when invocationArgs.Count >= 1 => $"{invocationArgs[0].Expression} == null", + _ => null + }; + + if (predicateBody != null) + { + return $"{parameter} => {predicateBody}"; + } + } + } + else if (actionExpression is ParenthesizedLambdaExpressionSyntax parenLambda && + parenLambda.ParameterList.Parameters.Count == 1) + { + var parameter = parenLambda.ParameterList.Parameters[0].Identifier.Text; + var body = parenLambda.Body; + + if (body is InvocationExpressionSyntax invocation && + invocation.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" }) + { + var methodName = memberAccess.Name.Identifier.Text; + var invocationArgs = invocation.ArgumentList.Arguments; + + var predicateBody = methodName switch + { + "True" when invocationArgs.Count >= 1 => invocationArgs[0].Expression.ToString(), + "False" when invocationArgs.Count >= 1 => $"!({invocationArgs[0].Expression})", + "NotNull" when invocationArgs.Count >= 1 => $"{invocationArgs[0].Expression} != null", + "Null" when invocationArgs.Count >= 1 => $"{invocationArgs[0].Expression} == null", + _ => null + }; + + if (predicateBody != null) + { + return $"{parameter} => {predicateBody}"; + } + } + } + + // Fallback: return the action as-is (may not work, but better than nothing) + return actionExpression.ToString(); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDoesNotMatch(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Matches, null, false, null); + + var pattern = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.Matches, $"await Assert.That({actual}).DoesNotMatch({pattern})", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertSubset(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); + + var superset = args[0].Expression.ToString(); + var subset = args[1].Expression.ToString(); + + return (AssertionConversionKind.Contains, $"await Assert.That({superset}).Contains({subset}).IgnoringOrder()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertSuperset(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); + + var subset = args[0].Expression.ToString(); + var superset = args[1].Expression.ToString(); + + return (AssertionConversionKind.Contains, $"await Assert.That({superset}).Contains({subset}).IgnoringOrder()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertProperSubset(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); + + // xUnit: Assert.ProperSubset(expectedSuperset, actual) + // actual should be a proper subset of expectedSuperset + // TUnit: Assert.That(first_arg).IsSubsetOf(second_arg) with TODO comment + var first = args[0].Expression.ToString(); + var second = args[1].Expression.ToString(); + + var todoComment = "// TODO: TUnit migration - ProperSubset requires strict subset (not equal). Add additional assertion if needed."; + return (AssertionConversionKind.Contains, $"await Assert.That({first}).IsSubsetOf({second})", true, todoComment); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertProperSuperset(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); + + // xUnit: Assert.ProperSuperset(expectedSubset, actual) + // actual should be a proper superset of expectedSubset + // TUnit: Assert.That(first_arg).IsSupersetOf(second_arg) with TODO comment + var first = args[0].Expression.ToString(); + var second = args[1].Expression.ToString(); + + var todoComment = "// TODO: TUnit migration - ProperSuperset requires strict superset (not equal). Add additional assertion if needed."; + return (AssertionConversionKind.Contains, $"await Assert.That({first}).IsSupersetOf({second})", true, todoComment); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertDistinct(SeparatedSyntaxList args) + { + if (args.Count < 1) return (AssertionConversionKind.Collection, null, false, null); + + var collection = args[0].Expression.ToString(); + + return (AssertionConversionKind.Collection, $"await Assert.That({collection}).HasDistinctItems()", true, null); + } + + private (AssertionConversionKind, string?, bool, string?) ConvertEquivalent(SeparatedSyntaxList args) + { + if (args.Count < 2) return (AssertionConversionKind.Equal, null, false, null); + + var expected = args[0].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + return (AssertionConversionKind.Equal, $"await Assert.That({actual}).IsEquivalentTo({expected})", true, null); + } + + private static string? GetGenericTypeArgument(MemberAccessExpressionSyntax memberAccess) + { + if (memberAccess.Name is GenericNameSyntax genericName) + { + return genericName.TypeArgumentList.Arguments.FirstOrDefault()?.ToString(); + } + return null; + } + + /// + /// Gets the message argument from xUnit assertions if present. + /// In xUnit, message is typically the last string argument. + /// Returns the full argument syntax (e.g., "\"my message\"") for use with .Because(). + /// + private static string? GetMessageArgument(SeparatedSyntaxList args, int startIndex) + { + // Look for a string literal argument at or after the startIndex + for (int i = startIndex; i < args.Count; i++) + { + var arg = args[i]; + // Named argument check (e.g., userMessage: "...") + if (arg.NameColon?.Name.Identifier.Text is "userMessage" or "message") + { + return arg.Expression.ToString(); + } + // String literal check + if (arg.Expression is LiteralExpressionSyntax literal && + literal.IsKind(SyntaxKind.StringLiteralExpression)) + { + return arg.Expression.ToString(); + } + // Interpolated string check + if (arg.Expression is InterpolatedStringExpressionSyntax) + { + return arg.Expression.ToString(); + } + } + return null; + } + + #endregion + + #region Attribute Analysis + + protected override bool ShouldRemoveAttribute(AttributeSyntax node) + { + var name = GetAttributeName(node); + return XUnitRemovableAttributeNames.Contains(name); + } + + protected override AttributeConversion? AnalyzeAttribute(AttributeSyntax node) + { + var name = GetAttributeName(node); + + if (!XUnitAttributeNames.Contains(name)) + return null; + + var conversion = name switch + { + "Fact" or "FactAttribute" => ConvertTestAttribute(node), + "Theory" or "TheoryAttribute" => ConvertTestAttribute(node), + "Trait" or "TraitAttribute" => new AttributeConversion + { + NewAttributeName = "Property", + NewArgumentList = null, // Keep original arguments + OriginalText = node.ToString() + }, + "InlineData" or "InlineDataAttribute" => new AttributeConversion + { + NewAttributeName = "Arguments", + NewArgumentList = null, // Keep original arguments + OriginalText = node.ToString() + }, + "MemberData" or "MemberDataAttribute" => ConvertMemberData(node), + "ClassData" or "ClassDataAttribute" => ConvertClassData(node), + "Collection" or "CollectionAttribute" => ConvertCollection(node), + "CollectionDefinition" or "CollectionDefinitionAttribute" => new AttributeConversion + { + NewAttributeName = "System.Obsolete", + NewArgumentList = "", // Remove arguments + OriginalText = node.ToString() + }, + _ => null + }; + + return conversion; + } + + private AttributeConversion ConvertTestAttribute(AttributeSyntax node) + { + // Check for Skip argument: [Fact(Skip = "reason")] or [Theory(Skip = "reason")] + var skipArg = node.ArgumentList?.Arguments + .FirstOrDefault(a => a.NameEquals?.Name.Identifier.ValueText == "Skip"); + + if (skipArg != null) + { + // Extract the skip reason and create additional Skip attribute + var skipReason = skipArg.Expression.ToString(); + return new AttributeConversion + { + NewAttributeName = "Test", + NewArgumentList = "", // Remove the Skip argument + OriginalText = node.ToString(), + AdditionalAttributes = new List + { + new AdditionalAttribute + { + Name = "Skip", + Arguments = $"({skipReason})" + } + } + }; + } + + return new AttributeConversion + { + NewAttributeName = "Test", + NewArgumentList = "", // Remove any arguments + OriginalText = node.ToString() + }; + } + + private AttributeConversion? ConvertCollection(AttributeSyntax node) + { + // [Collection("name")] on a test class needs to: + // 1. Find the CollectionDefinition class with matching name + // 2. Check if it has DisableParallelization = true (add [NotInParallel]) + // 3. Find ICollectionFixture interface and add ClassDataSource(Shared = SharedType.Keyed, Key = "name") + + var collectionNameArg = node.ArgumentList?.Arguments.FirstOrDefault()?.Expression; + if (collectionNameArg == null) + return null; + + var collectionName = collectionNameArg.ToString().Trim('"'); + + // Find the CollectionDefinition class + var collectionDefinition = FindCollectionDefinition(collectionName); + if (collectionDefinition == null) + { + // No CollectionDefinition found - just remove the attribute + return new AttributeConversion + { + NewAttributeName = "System.Obsolete", + NewArgumentList = "", + OriginalText = node.ToString() + }; + } + + // Check for DisableParallelization + var disableParallelization = HasDisableParallelization(collectionDefinition); + + // Find ICollectionFixture interface + var fixtureType = GetCollectionFixtureType(collectionDefinition); + + var additionalAttributes = new List(); + + // Add NotInParallel if DisableParallelization = true + if (disableParallelization) + { + additionalAttributes.Add(new AdditionalAttribute + { + Name = "NotInParallel", + Arguments = null + }); + } + + // If there's a fixture type, use ClassDataSource + if (fixtureType != null) + { + // The conversion produces ClassDataSource(Shared = SharedType.Keyed, Key = "name") + // We need to return this as the primary attribute conversion + var keyArgument = collectionNameArg.ToString(); // Keep the original string including quotes + + return new AttributeConversion + { + NewAttributeName = $"ClassDataSource<{fixtureType}>", + NewArgumentList = $"(Shared = SharedType.Keyed, Key = {keyArgument})", + OriginalText = node.ToString(), + AdditionalAttributes = additionalAttributes.Count > 0 ? additionalAttributes : null + }; + } + else if (disableParallelization) + { + // No fixture, just add NotInParallel + return new AttributeConversion + { + NewAttributeName = "NotInParallel", + NewArgumentList = "", + OriginalText = node.ToString() + }; + } + + // No fixture, no parallelization disable - just remove the attribute + return new AttributeConversion + { + NewAttributeName = "System.Obsolete", + NewArgumentList = "", + OriginalText = node.ToString() + }; + } + + private ClassDeclarationSyntax? FindCollectionDefinition(string collectionName) + { + // Search all classes in the compilation for [CollectionDefinition("name")] + foreach (var tree in Compilation.SyntaxTrees) + { + var root = tree.GetRoot(); + foreach (var classDecl in root.DescendantNodes().OfType()) + { + foreach (var attrList in classDecl.AttributeLists) + { + foreach (var attr in attrList.Attributes) + { + var attrName = GetAttributeName(attr); + if (attrName == "CollectionDefinition" || attrName == "CollectionDefinitionAttribute") + { + var nameArg = attr.ArgumentList?.Arguments.FirstOrDefault()?.Expression?.ToString()?.Trim('"'); + if (nameArg == collectionName) + { + return classDecl; + } + } + } + } + } + } + return null; + } + + private bool HasDisableParallelization(ClassDeclarationSyntax collectionDefinition) + { + // Check for [CollectionDefinition("name", DisableParallelization = true)] + foreach (var attrList in collectionDefinition.AttributeLists) + { + foreach (var attr in attrList.Attributes) + { + var attrName = GetAttributeName(attr); + if (attrName == "CollectionDefinition" || attrName == "CollectionDefinitionAttribute") + { + var disableArg = attr.ArgumentList?.Arguments + .FirstOrDefault(a => a.NameEquals?.Name.Identifier.Text == "DisableParallelization"); + if (disableArg?.Expression is LiteralExpressionSyntax literal && + literal.Token.ValueText == "true") + { + return true; + } + } + } + } + return false; + } + + private string? GetCollectionFixtureType(ClassDeclarationSyntax collectionDefinition) + { + // Find ICollectionFixture in the base list + if (collectionDefinition.BaseList == null) + return null; + + foreach (var baseType in collectionDefinition.BaseList.Types) + { + var typeName = baseType.Type.ToString(); + if (typeName.StartsWith("ICollectionFixture<")) + { + // Extract the type argument + if (baseType.Type is GenericNameSyntax genericName) + { + return genericName.TypeArgumentList.Arguments.FirstOrDefault()?.ToString(); + } + } + } + + return null; + } + + private AttributeConversion? ConvertMemberData(AttributeSyntax node) + { + // [MemberData(nameof(Data))] -> [MethodDataSource(nameof(Data))] + // [MemberData(nameof(Data), MemberType = typeof(Foo))] -> [MethodDataSource(typeof(Foo), nameof(Data))] + var args = node.ArgumentList?.Arguments; + if (args == null || args.Value.Count == 0) + return null; + + var memberName = args.Value[0].Expression.ToString(); + + // Check for MemberType named argument + var memberTypeArg = args.Value.Skip(1) + .FirstOrDefault(a => a.NameEquals?.Name.Identifier.Text == "MemberType"); + + if (memberTypeArg != null) + { + var memberType = memberTypeArg.Expression.ToString(); + return new AttributeConversion + { + NewAttributeName = "MethodDataSource", + NewArgumentList = $"({memberType}, {memberName})", + OriginalText = node.ToString() + }; + } + + return new AttributeConversion + { + NewAttributeName = "MethodDataSource", + NewArgumentList = $"({memberName})", + OriginalText = node.ToString() + }; + } + + private AttributeConversion? ConvertClassData(AttributeSyntax node) + { + // [ClassData(typeof(TestDataGenerator))] -> [MethodDataSource(typeof(TestDataGenerator), "GetEnumerator")] + var args = node.ArgumentList?.Arguments; + if (args == null || args.Value.Count == 0) + return null; + + var typeArg = args.Value[0].Expression.ToString(); + + return new AttributeConversion + { + NewAttributeName = "MethodDataSource", + NewArgumentList = $"({typeArg}, \"GetEnumerator\")", + OriginalText = node.ToString() + }; + } + + private static string GetAttributeName(AttributeSyntax attribute) + { + return attribute.Name switch + { + SimpleNameSyntax simpleName => simpleName.Identifier.Text, + QualifiedNameSyntax qualifiedName => qualifiedName.Right.Identifier.Text, + _ => "" + }; + } + + #endregion + + #region Base Type Analysis + + protected override CompilationUnitSyntax AnalyzeBaseTypes(CompilationUnitSyntax root) + { + var classNodes = OriginalRoot.DescendantNodes().OfType().ToList(); + var currentRoot = root; + + foreach (var classNode in classNodes) + { + if (classNode.BaseList == null) continue; + + var hasIAsyncLifetime = false; + var classFixtureTypes = new List(); + + foreach (var originalBaseType in classNode.BaseList.Types) + { + try + { + var typeName = originalBaseType.Type.ToString(); + + // Check for IAsyncLifetime + if (typeName == "IAsyncLifetime" || IsIAsyncLifetimeType(originalBaseType)) + { + hasIAsyncLifetime = true; + } + + // Check for IClassFixture + if (typeName.StartsWith("IClassFixture<") || IsIClassFixtureType(originalBaseType)) + { + var fixtureType = ExtractGenericTypeArgument(originalBaseType); + if (fixtureType != null) + { + classFixtureTypes.Add(fixtureType); + } + } + + if (ShouldRemoveBaseType(originalBaseType)) + { + var removal = new BaseTypeRemoval + { + TypeName = typeName, + OriginalText = originalBaseType.ToString() + }; + Plan.BaseTypeRemovals.Add(removal); + + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalBaseType.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(removal.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "BaseTypeAnalysis", + Description = ex.Message, + OriginalCode = originalBaseType.ToString(), + Exception = ex + }); + } + } + + // If this class implements IAsyncLifetime, handle based on whether it's a test class + if (hasIAsyncLifetime) + { + try + { + var isTestClass = HasTestMethods(classNode); + + if (isTestClass) + { + // Test class: Add [Before(Test)]/[After(Test)] attributes + currentRoot = AnalyzeLifecycleMethods(currentRoot, classNode); + } + else + { + // Non-test class: Add IAsyncInitializer and IAsyncDisposable interfaces + currentRoot = AnalyzeNonTestLifecycleMethods(currentRoot, classNode); + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "LifecycleMethodAnalysis", + Description = ex.Message, + OriginalCode = classNode.Identifier.Text, + Exception = ex + }); + } + } + + // If this class implements IClassFixture, add ClassDataSource attribute + foreach (var fixtureType in classFixtureTypes) + { + try + { + var classAttrAddition = new ClassAttributeAddition + { + AttributeCode = $"ClassDataSource<{fixtureType}>(Shared = SharedType.PerClass)", + OriginalText = classNode.Identifier.Text + }; + Plan.ClassAttributeAdditions.Add(classAttrAddition); + + // Find the class in current root and annotate + var classToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(c => c.Span == classNode.Span); + + if (classToAnnotate != null) + { + var annotatedClass = classToAnnotate.WithAdditionalAnnotations(classAttrAddition.Annotation); + currentRoot = currentRoot.ReplaceNode(classToAnnotate, annotatedClass); + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "ClassAttributeAddition", + Description = ex.Message, + OriginalCode = classNode.Identifier.Text, + Exception = ex + }); + } + } + } + + return currentRoot; + } + + private CompilationUnitSyntax AnalyzeLifecycleMethods(CompilationUnitSyntax root, ClassDeclarationSyntax originalClass) + { + var currentRoot = root; + + // Find InitializeAsync and DisposeAsync methods in the original class + foreach (var method in originalClass.Members.OfType()) + { + try + { + var methodName = method.Identifier.Text; + + if (methodName == "InitializeAsync") + { + // Add [Before(Test)] attribute and change return type to Task + var methodAttrAddition = new MethodAttributeAddition + { + AttributeCode = "Before(Test)", + NewReturnType = method.ReturnType.ToString() == "ValueTask" ? "Task" : null, + OriginalText = method.Identifier.Text + }; + Plan.MethodAttributeAdditions.Add(methodAttrAddition); + + // Find method in current root and annotate + var methodToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Span == method.Span); + + if (methodToAnnotate != null) + { + var annotatedMethod = methodToAnnotate.WithAdditionalAnnotations(methodAttrAddition.Annotation); + currentRoot = currentRoot.ReplaceNode(methodToAnnotate, annotatedMethod); + } + } + else if (methodName == "DisposeAsync") + { + // Add [After(Test)] attribute and change return type to Task + var methodAttrAddition = new MethodAttributeAddition + { + AttributeCode = "After(Test)", + NewReturnType = method.ReturnType.ToString() == "ValueTask" ? "Task" : null, + OriginalText = method.Identifier.Text + }; + Plan.MethodAttributeAdditions.Add(methodAttrAddition); + + // Find method in current root and annotate + var methodToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Span == method.Span); + + if (methodToAnnotate != null) + { + var annotatedMethod = methodToAnnotate.WithAdditionalAnnotations(methodAttrAddition.Annotation); + currentRoot = currentRoot.ReplaceNode(methodToAnnotate, annotatedMethod); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "LifecycleMethodAnalysis", + Description = ex.Message, + OriginalCode = method.Identifier.Text, + Exception = ex + }); + } + } + + return currentRoot; + } + + private bool HasTestMethods(ClassDeclarationSyntax classNode) + { + // Check if any method in the class has test-related attributes + foreach (var method in classNode.Members.OfType()) + { + foreach (var attrList in method.AttributeLists) + { + foreach (var attr in attrList.Attributes) + { + var attrName = GetAttributeName(attr); + if (attrName is "Fact" or "FactAttribute" or "Theory" or "TheoryAttribute" or + "Test" or "TestAttribute") + { + return true; + } + } + } + } + + return false; + } + + private CompilationUnitSyntax AnalyzeNonTestLifecycleMethods(CompilationUnitSyntax root, ClassDeclarationSyntax originalClass) + { + var currentRoot = root; + + // For non-test classes implementing IAsyncLifetime, we need to: + // 1. Add IAsyncInitializer and IAsyncDisposable interfaces + // 2. Change InitializeAsync return type from ValueTask to Task + + // Add base types for the class + var classToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(c => c.Span == originalClass.Span); + + if (classToAnnotate != null) + { + // Add IAsyncInitializer base type + var initializerAddition = new BaseTypeAddition + { + TypeName = "IAsyncInitializer", + OriginalText = originalClass.Identifier.Text + }; + Plan.BaseTypeAdditions.Add(initializerAddition); + + var annotatedClass = classToAnnotate.WithAdditionalAnnotations(initializerAddition.Annotation); + currentRoot = currentRoot.ReplaceNode(classToAnnotate, annotatedClass); + + // Find class again after modification + classToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(c => c.HasAnnotation(initializerAddition.Annotation)); + + if (classToAnnotate != null) + { + // Add IAsyncDisposable base type + var disposableAddition = new BaseTypeAddition + { + TypeName = "IAsyncDisposable", + OriginalText = originalClass.Identifier.Text + }; + Plan.BaseTypeAdditions.Add(disposableAddition); + + annotatedClass = classToAnnotate.WithAdditionalAnnotations(disposableAddition.Annotation); + currentRoot = currentRoot.ReplaceNode(classToAnnotate, annotatedClass); + } + } + + // Change InitializeAsync return type from ValueTask to Task + foreach (var method in originalClass.Members.OfType()) + { + if (method.Identifier.Text == "InitializeAsync" && + method.ReturnType.ToString() == "ValueTask") + { + var signatureChange = new MethodSignatureChange + { + ChangeValueTaskToTask = true, + OriginalText = method.Identifier.Text + }; + Plan.MethodSignatureChanges.Add(signatureChange); + + var methodToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(m => m.Span == method.Span); + + if (methodToAnnotate != null) + { + var annotatedMethod = methodToAnnotate.WithAdditionalAnnotations(signatureChange.Annotation); + currentRoot = currentRoot.ReplaceNode(methodToAnnotate, annotatedMethod); + } + } + } + + return currentRoot; + } + + private bool IsIAsyncLifetimeType(BaseTypeSyntax baseType) + { + var typeInfo = SemanticModel.GetTypeInfo(baseType.Type); + return typeInfo.Type?.ToDisplayString() == "Xunit.IAsyncLifetime"; + } + + private bool IsIClassFixtureType(BaseTypeSyntax baseType) + { + var typeInfo = SemanticModel.GetTypeInfo(baseType.Type); + var displayString = typeInfo.Type?.ToDisplayString() ?? ""; + return displayString.StartsWith("Xunit.IClassFixture<"); + } + + private string? ExtractGenericTypeArgument(BaseTypeSyntax baseType) + { + if (baseType.Type is GenericNameSyntax genericName) + { + return genericName.TypeArgumentList.Arguments.FirstOrDefault()?.ToString(); + } + + // Try semantic model for qualified names + var typeInfo = SemanticModel.GetTypeInfo(baseType.Type); + if (typeInfo.Type is INamedTypeSymbol namedType && namedType.TypeArguments.Length > 0) + { + return namedType.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + } + + return null; + } + + protected override bool ShouldRemoveBaseType(BaseTypeSyntax baseType) + { + var typeName = baseType.Type.ToString(); + + // Check for generic interface patterns like IClassFixture + foreach (var xunitType in XUnitBaseTypes) + { + if (typeName.StartsWith(xunitType + "<") || typeName == xunitType) + { + return true; + } + } + + // Semantic check for xUnit types + var typeInfo = SemanticModel.GetTypeInfo(baseType.Type); + if (typeInfo.Type != null) + { + var ns = typeInfo.Type.ContainingNamespace?.ToDisplayString(); + if (ns?.StartsWith("Xunit") == true) + { + return true; + } + } + + return false; + } + + #endregion + + #region Member Analysis + + protected override CompilationUnitSyntax AnalyzeMembers(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Find ITestOutputHelper fields and properties on ORIGINAL tree (for semantic analysis) + var members = OriginalRoot.DescendantNodes() + .Where(n => + (n is PropertyDeclarationSyntax prop && IsTestOutputHelperType(prop.Type)) || + (n is FieldDeclarationSyntax field && IsTestOutputHelperType(field.Declaration.Type))) + .ToList(); + + foreach (var originalMember in members) + { + try + { + var removal = new MemberRemoval + { + MemberName = GetMemberName(originalMember), + OriginalText = originalMember.ToString() + }; + + Plan.MemberRemovals.Add(removal); + + // Find corresponding node in current tree by span + var nodeToAnnotate = currentRoot.DescendantNodes() + .FirstOrDefault(n => n.Span == originalMember.Span); + + if (nodeToAnnotate != null) + { + var annotatedMember = nodeToAnnotate.WithAdditionalAnnotations(removal.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedMember); + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "MemberAnalysis", + Description = ex.Message, + OriginalCode = originalMember.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + private bool IsTestOutputHelperType(TypeSyntax type) + { + var typeName = type.ToString(); + if (typeName == "ITestOutputHelper") + return true; + + // Use semantic model on type from ORIGINAL tree + var typeInfo = SemanticModel.GetTypeInfo(type); + return typeInfo.Type?.ToDisplayString() == "Xunit.Abstractions.ITestOutputHelper"; + } + + private static string GetMemberName(SyntaxNode member) + { + return member switch + { + PropertyDeclarationSyntax prop => prop.Identifier.Text, + FieldDeclarationSyntax field => field.Declaration.Variables.FirstOrDefault()?.Identifier.Text ?? "", + _ => "" + }; + } + + #endregion + + #region Constructor Parameter Analysis + + protected override CompilationUnitSyntax AnalyzeConstructorParameters(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Find parameters on ORIGINAL tree (for semantic analysis) + var parameters = OriginalRoot.DescendantNodes() + .OfType() + .Where(p => p.Type != null && IsTestOutputHelperType(p.Type)) + .ToList(); + + foreach (var originalParam in parameters) + { + try + { + var removal = new ConstructorParameterRemoval + { + ParameterName = originalParam.Identifier.Text, + ParameterType = originalParam.Type?.ToString() ?? "", + OriginalText = originalParam.ToString() + }; + + Plan.ConstructorParameterRemovals.Add(removal); + + // Find corresponding node in current tree by span + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(p => p.Span == originalParam.Span); + + if (nodeToAnnotate != null) + { + var annotatedParam = nodeToAnnotate.WithAdditionalAnnotations(removal.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedParam); + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "ConstructorParameterAnalysis", + Description = ex.Message, + OriginalCode = originalParam.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + #endregion + + protected override void AnalyzeUsings() + { + Plan.UsingPrefixesToRemove.Add("Xunit"); + // TUnit usings are handled automatically by MigrationHelpers + } + + #region Special Invocation Analysis + + protected override CompilationUnitSyntax AnalyzeSpecialInvocations(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Analyze Record.Exception calls + currentRoot = AnalyzeRecordExceptionCalls(currentRoot); + + // Analyze ITestOutputHelper.WriteLine calls + currentRoot = AnalyzeTestOutputHelperCalls(currentRoot); + + return currentRoot; + } + + private CompilationUnitSyntax AnalyzeRecordExceptionCalls(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Find Record.Exception calls on the ORIGINAL tree + var recordExceptionCalls = OriginalRoot.DescendantNodes() + .OfType() + .Where(IsRecordExceptionCall) + .ToList(); + + foreach (var originalCall in recordExceptionCalls) + { + try + { + // Check if it's assigned to a variable + var parent = originalCall.Parent; + if (parent is not EqualsValueClauseSyntax equalsClause || + equalsClause.Parent is not VariableDeclaratorSyntax declarator) + { + continue; // Only handle variable assignments + } + + var variableName = declarator.Identifier.Text; + + // Extract the lambda body + var lambda = originalCall.ArgumentList.Arguments.FirstOrDefault()?.Expression; + if (lambda == null) continue; + + string tryBlockBody; + if (lambda is ParenthesizedLambdaExpressionSyntax parenLambda) + { + tryBlockBody = ExtractLambdaBody(parenLambda.Body); + } + else if (lambda is SimpleLambdaExpressionSyntax simpleLambda) + { + tryBlockBody = ExtractLambdaBody(simpleLambda.Body); + } + else + { + continue; + } + + var conversion = new RecordExceptionConversion + { + VariableName = variableName, + TryBlockBody = tryBlockBody, + OriginalText = originalCall.ToString() + }; + + Plan.RecordExceptionConversions.Add(conversion); + + // Find the containing statement (the variable declaration statement) + var declarationStatement = declarator.Ancestors() + .OfType() + .FirstOrDefault(); + + if (declarationStatement != null) + { + // Find corresponding node in current tree + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == declarationStatement.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(conversion.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "RecordExceptionAnalysis", + Description = ex.Message, + OriginalCode = originalCall.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + private bool IsRecordExceptionCall(InvocationExpressionSyntax invocation) + { + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Record" } && + memberAccess.Name.Identifier.Text == "Exception") + { + // Semantic check to confirm it's xUnit + var symbolInfo = SemanticModel.GetSymbolInfo(invocation); + if (symbolInfo.Symbol is IMethodSymbol methodSymbol) + { + var containingType = methodSymbol.ContainingType?.ToDisplayString(); + return containingType?.StartsWith("Xunit.Record") == true; + } + } + return false; + } + + private static string ExtractLambdaBody(CSharpSyntaxNode body) + { + return body switch + { + BlockSyntax block => string.Join("\n", block.Statements.Select(s => s.ToString())), + ExpressionSyntax expr => expr.ToString() + ";", + _ => body.ToString() + }; + } + + private CompilationUnitSyntax AnalyzeTestOutputHelperCalls(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Find ITestOutputHelper.WriteLine calls on the ORIGINAL tree + var testOutputHelperCalls = OriginalRoot.DescendantNodes() + .OfType() + .Where(IsTestOutputHelperWriteLineCall) + .ToList(); + + foreach (var originalCall in testOutputHelperCalls) + { + try + { + // Build the Console.WriteLine replacement + var arguments = originalCall.ArgumentList.ToString(); + var replacementCode = $"Console.WriteLine{arguments}"; + + var replacement = new InvocationReplacement + { + ReplacementCode = replacementCode, + OriginalText = originalCall.ToString() + }; + + Plan.InvocationReplacements.Add(replacement); + + // Find corresponding node in current tree + var nodeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalCall.Span); + + if (nodeToAnnotate != null) + { + var annotatedNode = nodeToAnnotate.WithAdditionalAnnotations(replacement.Annotation); + currentRoot = currentRoot.ReplaceNode(nodeToAnnotate, annotatedNode); + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "TestOutputHelperAnalysis", + Description = ex.Message, + OriginalCode = originalCall.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + private bool IsTestOutputHelperWriteLineCall(InvocationExpressionSyntax invocation) + { + if (invocation.Expression is MemberAccessExpressionSyntax memberAccess && + memberAccess.Name.Identifier.Text == "WriteLine") + { + // Check the receiver type semantically + var receiverInfo = SemanticModel.GetTypeInfo(memberAccess.Expression); + var typeName = receiverInfo.Type?.ToDisplayString(); + if (typeName == "Xunit.Abstractions.ITestOutputHelper") + { + return true; + } + + // Fallback: syntactic check for common patterns + // e.g., _testOutputHelper.WriteLine, TestOutputHelper.WriteLine, testOutputHelper.WriteLine + var receiverName = memberAccess.Expression switch + { + IdentifierNameSyntax id => id.Identifier.Text, + MemberAccessExpressionSyntax ma => ma.Name.Identifier.Text, + _ => null + }; + + if (receiverName != null && + (receiverName.Contains("testOutputHelper", StringComparison.OrdinalIgnoreCase) || + receiverName.Contains("TestOutputHelper", StringComparison.OrdinalIgnoreCase) || + receiverName.EndsWith("OutputHelper", StringComparison.OrdinalIgnoreCase))) + { + return true; + } + } + return false; + } + + #endregion + + #region TheoryData Analysis + + protected override CompilationUnitSyntax AnalyzeTheoryData(CompilationUnitSyntax root) + { + var currentRoot = root; + + // Find all TheoryData types in the original tree (field and property declarations) + var theoryDataNodes = OriginalRoot.DescendantNodes() + .OfType() + .Where(g => g.Identifier.Text == "TheoryData") + .ToList(); + + foreach (var originalGeneric in theoryDataNodes) + { + try + { + // Get the type argument + var typeArg = originalGeneric.TypeArgumentList.Arguments.FirstOrDefault(); + if (typeArg == null) continue; + + var elementType = typeArg.ToString(); + + // Create annotations for both the type and the object creation + var typeAnnotation = new SyntaxAnnotation("TUnitMigration", Guid.NewGuid().ToString()); + var creationAnnotation = new SyntaxAnnotation("TUnitMigration", Guid.NewGuid().ToString()); + + var conversion = new TheoryDataConversion + { + ElementType = elementType, + TypeAnnotation = typeAnnotation, + CreationAnnotation = creationAnnotation, + OriginalText = originalGeneric.ToString() + }; + Plan.TheoryDataConversions.Add(conversion); + + // Find and annotate the GenericName in the current tree + var typeToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == originalGeneric.Span); + + if (typeToAnnotate != null) + { + var annotatedType = typeToAnnotate.WithAdditionalAnnotations(typeAnnotation); + currentRoot = currentRoot.ReplaceNode(typeToAnnotate, annotatedType); + } + + // Also find the corresponding object creation expression and annotate it + // The object creation is typically a sibling or descendant node in the variable declaration + var parent = originalGeneric.Parent; + while (parent != null && parent is not VariableDeclarationSyntax) + { + parent = parent.Parent; + } + + if (parent is VariableDeclarationSyntax varDecl) + { + foreach (var declarator in varDecl.Variables) + { + if (declarator.Initializer?.Value is BaseObjectCreationExpressionSyntax creation) + { + var creationToAnnotate = currentRoot.DescendantNodes() + .OfType() + .FirstOrDefault(n => n.Span == creation.Span); + + if (creationToAnnotate != null) + { + var annotatedCreation = creationToAnnotate.WithAdditionalAnnotations(creationAnnotation); + currentRoot = currentRoot.ReplaceNode(creationToAnnotate, annotatedCreation); + } + } + } + } + } + catch (Exception ex) + { + Plan.Failures.Add(new ConversionFailure + { + Phase = "TheoryDataAnalysis", + Description = ex.Message, + OriginalCode = originalGeneric.ToString(), + Exception = ex + }); + } + } + + return currentRoot; + } + + #endregion +} diff --git a/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs index aabff8af2e..87e55ed2ec 100644 --- a/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs +++ b/TUnit.Analyzers.CodeFixers/XUnitMigrationCodeFixProvider.cs @@ -1,11 +1,11 @@ -using System.Collections.Immutable; using System.Composition; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Analyzers.CodeFixers.Base; +using TUnit.Analyzers.CodeFixers.Base.TwoPhase; +using TUnit.Analyzers.CodeFixers.TwoPhase; namespace TUnit.Analyzers.CodeFixers; @@ -18,14 +18,23 @@ public class XUnitMigrationCodeFixProvider : BaseMigrationCodeFixProvider protected override bool ShouldAddTUnitUsings() => true; + protected override MigrationAnalyzer? CreateTwoPhaseAnalyzer(SemanticModel semanticModel, Compilation compilation) + { + return new XUnitTwoPhaseAnalyzer(semanticModel, compilation); + } + + // The following methods are required by the base class but are only used in the legacy + // rewriter-based approach. Since XUnit always uses the two-phase architecture (above), + // these implementations are never called - they exist only to satisfy the abstract contract. + protected override AttributeRewriter CreateAttributeRewriter(Compilation compilation) { - return new XUnitAttributeRewriter(); + return new PassThroughAttributeRewriter(); } protected override CSharpSyntaxRewriter CreateAssertionRewriter(SemanticModel semanticModel, Compilation compilation) { - return new XUnitAssertionRewriter(semanticModel); + return new PassThroughRewriter(); } protected override CSharpSyntaxRewriter CreateBaseTypeRewriter(SemanticModel semanticModel, Compilation compilation) @@ -38,1804 +47,41 @@ protected override CSharpSyntaxRewriter CreateLifecycleRewriter(Compilation comp return new PassThroughRewriter(); } - private class PassThroughRewriter : CSharpSyntaxRewriter - { - } - - protected override CompilationUnitSyntax ApplyFrameworkSpecificConversions(CompilationUnitSyntax compilationUnit, SemanticModel semanticModel, Compilation compilation) - { - // Use the original syntax tree from the semantic model, not from the (potentially modified) compilation unit - // After assertion rewriting, compilationUnit.SyntaxTree is a new tree not in the compilation - var syntaxTree = semanticModel.SyntaxTree; - SyntaxNode updatedRoot = compilationUnit; - - updatedRoot = UpdateInitializeDispose(compilation, updatedRoot); - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref updatedRoot); - - updatedRoot = UpdateClassAttributes(compilation, updatedRoot); - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref updatedRoot); - - updatedRoot = RemoveInterfacesAndBaseClasses(compilation, updatedRoot); - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref updatedRoot); - - updatedRoot = ConvertTheoryData(compilation, updatedRoot); - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref updatedRoot); - - updatedRoot = ConvertTestOutputHelpers(ref compilation, ref syntaxTree, updatedRoot); - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref updatedRoot); - - updatedRoot = ConvertRecordException(ref compilation, ref syntaxTree, updatedRoot); - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref updatedRoot); - - return (CompilationUnitSyntax)updatedRoot; - } - - private static SyntaxNode ConvertTestOutputHelpers(ref Compilation compilation, ref SyntaxTree syntaxTree, SyntaxNode root) - { - var currentRoot = root; - - var compilationValue = compilation; - - while (currentRoot.DescendantNodes() - .OfType() - .FirstOrDefault(x => IsTestOutputHelperInvocation(compilationValue, x)) - is { } invocationExpressionSyntax) - { - var memberAccessExpressionSyntax = (MemberAccessExpressionSyntax) invocationExpressionSyntax.Expression; - - currentRoot = currentRoot.ReplaceNode( - invocationExpressionSyntax, - invocationExpressionSyntax.WithExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Console"), - SyntaxFactory.IdentifierName(memberAccessExpressionSyntax.Name.Identifier.Text) - ) - ) - ); - - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref currentRoot); - compilationValue = compilation; - } - - while (currentRoot.DescendantNodes() - .OfType() - .FirstOrDefault(x => x.Type?.TryGetInferredMemberName() == "ITestOutputHelper") - is { } parameterSyntax) - { - currentRoot = currentRoot.RemoveNode(parameterSyntax, SyntaxRemoveOptions.KeepNoTrivia)!; - } - - var membersToRemove = currentRoot.DescendantNodes() - .Where(n => (n is PropertyDeclarationSyntax prop && prop.Type.TryGetInferredMemberName() == "ITestOutputHelper") || - (n is FieldDeclarationSyntax field && field.Declaration.Type.TryGetInferredMemberName() == "ITestOutputHelper")) - .ToList(); - - if (membersToRemove.Count > 0) - { - currentRoot = currentRoot.RemoveNodes(membersToRemove, SyntaxRemoveOptions.KeepNoTrivia)!; - } - - return currentRoot; - } - - private static bool IsTestOutputHelperInvocation(Compilation compilation, InvocationExpressionSyntax invocationExpressionSyntax) - { - var semanticModel = compilation.GetSemanticModel(invocationExpressionSyntax.SyntaxTree); - - var symbolInfo = semanticModel.GetSymbolInfo(invocationExpressionSyntax); - - if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) - { - return false; - } - - if (invocationExpressionSyntax.Expression is not MemberAccessExpressionSyntax) - { - return false; - } - - return methodSymbol.ContainingType?.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) - is "global::Xunit.Abstractions.ITestOutputHelper" or "global::Xunit.ITestOutputHelper"; - } - - private static SyntaxNode ConvertRecordException(ref Compilation compilation, ref SyntaxTree syntaxTree, SyntaxNode root) - { - var currentRoot = root; - var compilationValue = compilation; - - // Find local declarations with Record.Exception() or Record.ExceptionAsync() - while (currentRoot.DescendantNodes() - .OfType() - .FirstOrDefault(x => IsRecordExceptionDeclaration(compilationValue, x)) - is { } localDeclaration) - { - var variableDeclarator = localDeclaration.Declaration.Variables.First(); - var variableName = variableDeclarator.Identifier.Text; - var invocation = variableDeclarator.Initializer?.Value as InvocationExpressionSyntax; - - if (invocation == null) - { - break; - } - - // Get the action/func argument - var actionArg = invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression; - if (actionArg == null) - { - break; - } - - // Check if this is async (Record.ExceptionAsync) - var isAsync = invocation.Expression is MemberAccessExpressionSyntax memberAccess && - memberAccess.Name.Identifier.Text == "ExceptionAsync"; - - // Extract the body from the lambda/action - var actionBody = ExtractActionBody(actionArg); - - // Create the try-catch replacement - // Exception? ex = null; - // try { actionBody; } catch (Exception e) { ex = e; } - var statements = CreateTryCatchStatements(variableName, actionBody, isAsync, localDeclaration.GetLeadingTrivia()); - - // Replace the local declaration with the try-catch statements - var parent = localDeclaration.Parent; - if (parent is BlockSyntax block) - { - var index = block.Statements.IndexOf(localDeclaration); - var newStatements = block.Statements - .RemoveAt(index) - .InsertRange(index, statements); - var newBlock = block.WithStatements(newStatements); - currentRoot = currentRoot.ReplaceNode(block, newBlock); - } - else - { - // If not in a block, just replace with the first statement (may not work perfectly) - currentRoot = currentRoot.ReplaceNode(localDeclaration, statements.First()); - } - - UpdateSyntaxTrees(ref compilation, ref syntaxTree, ref currentRoot); - compilationValue = compilation; - } - - return currentRoot; - } - - private static bool IsRecordExceptionDeclaration(Compilation compilation, LocalDeclarationStatementSyntax localDeclaration) - { - var variableDeclarator = localDeclaration.Declaration.Variables.FirstOrDefault(); - if (variableDeclarator?.Initializer?.Value is not InvocationExpressionSyntax invocation) - { - return false; - } - - if (invocation.Expression is not MemberAccessExpressionSyntax memberAccess) - { - return false; - } - - // Check if it's Record.Exception or Record.ExceptionAsync by name first (fast check) - if (memberAccess.Expression is not IdentifierNameSyntax { Identifier.Text: "Record" }) - { - return false; - } - - if (memberAccess.Name.Identifier.Text is not ("Exception" or "ExceptionAsync")) - { - return false; - } - - // Verify with semantic model - var semanticModel = compilation.GetSemanticModel(invocation.SyntaxTree); - var symbolInfo = semanticModel.GetSymbolInfo(invocation); - - if (symbolInfo.Symbol is not IMethodSymbol methodSymbol) - { - return false; - } - - return methodSymbol.ContainingType?.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) - is "global::Xunit.Record" or "global::Xunit.Assert"; - } - - private static StatementSyntax ExtractActionBody(ExpressionSyntax actionExpression) - { - // Handle lambda expressions: () => SomeMethod() or () => { statements } - if (actionExpression is SimpleLambdaExpressionSyntax simpleLambda) - { - return ConvertLambdaBodyToStatement(simpleLambda.Body); - } - - if (actionExpression is ParenthesizedLambdaExpressionSyntax parenLambda) - { - return ConvertLambdaBodyToStatement(parenLambda.Body); - } - - // Handle method group or direct invocation - // For Action delegates, wrap in invocation - return SyntaxFactory.ExpressionStatement( - SyntaxFactory.InvocationExpression(actionExpression)); - } - - private static StatementSyntax ConvertLambdaBodyToStatement(CSharpSyntaxNode body) - { - if (body is BlockSyntax block) - { - // If it's a single statement block, extract the statement - if (block.Statements.Count == 1) - { - return block.Statements[0]; - } - // Otherwise return the block as-is (will need to be a compound statement) - return block; - } - - // Handle throw expressions - convert to throw statement - if (body is ThrowExpressionSyntax throwExpression) - { - return SyntaxFactory.ThrowStatement(throwExpression.Expression); - } - - if (body is ExpressionSyntax expression) - { - return SyntaxFactory.ExpressionStatement(expression); - } - - // Fallback - wrap in expression statement - return SyntaxFactory.ExpressionStatement( - SyntaxFactory.IdentifierName("/* TODO: Convert lambda body */")); - } - - private static IEnumerable CreateTryCatchStatements( - string variableName, - StatementSyntax actionBody, - bool isAsync, - SyntaxTriviaList leadingTrivia) - { - // Extract the base indentation from leading trivia - var indentation = leadingTrivia.Where(t => t.IsKind(SyntaxKind.WhitespaceTrivia)).LastOrDefault(); - var indentString = indentation.ToFullString(); - - // Exception? variableName = null; - var nullableExceptionType = SyntaxFactory.NullableType( - SyntaxFactory.IdentifierName("Exception")); - - var declarationStatement = SyntaxFactory.LocalDeclarationStatement( - SyntaxFactory.VariableDeclaration(nullableExceptionType) - .WithVariables(SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.VariableDeclarator(variableName) - .WithInitializer(SyntaxFactory.EqualsValueClause( - SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)))))) - .WithLeadingTrivia(leadingTrivia) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - - // Prepare the action body with proper indentation (inside try block) - var tryBodyIndent = indentString + " "; - StatementSyntax tryBodyStatement; - - if (actionBody is BlockSyntax blockBody) - { - // If it's a block, use its statements directly - tryBodyStatement = blockBody - .WithOpenBraceToken(blockBody.OpenBraceToken.WithLeadingTrivia(SyntaxFactory.Whitespace(tryBodyIndent))) - .WithCloseBraceToken(blockBody.CloseBraceToken.WithLeadingTrivia(SyntaxFactory.Whitespace(tryBodyIndent))); - } - else - { - tryBodyStatement = actionBody - .WithLeadingTrivia(SyntaxFactory.Whitespace(tryBodyIndent)) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - } - - // If async, we may need to await the action - if (isAsync && tryBodyStatement is ExpressionStatementSyntax exprStmt) - { - var awaitExpr = SyntaxFactory.AwaitExpression( - SyntaxFactory.Token(SyntaxKind.AwaitKeyword).WithTrailingTrivia(SyntaxFactory.Space), - exprStmt.Expression); - tryBodyStatement = SyntaxFactory.ExpressionStatement(awaitExpr) - .WithLeadingTrivia(SyntaxFactory.Whitespace(tryBodyIndent)) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - } - - // catch (Exception e) { variableName = e; } - var catchClause = SyntaxFactory.CatchClause() - .WithCatchKeyword(SyntaxFactory.Token(SyntaxKind.CatchKeyword) - .WithLeadingTrivia(SyntaxFactory.CarriageReturnLineFeed, SyntaxFactory.Whitespace(indentString))) - .WithDeclaration(SyntaxFactory.CatchDeclaration( - SyntaxFactory.IdentifierName("Exception"), - SyntaxFactory.Identifier("e"))) - .WithBlock(SyntaxFactory.Block( - SyntaxFactory.ExpressionStatement( - SyntaxFactory.AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - SyntaxFactory.IdentifierName(variableName), - SyntaxFactory.IdentifierName("e"))) - .WithLeadingTrivia(SyntaxFactory.Whitespace(tryBodyIndent)) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed)) - .WithOpenBraceToken(SyntaxFactory.Token(SyntaxKind.OpenBraceToken) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed)) - .WithCloseBraceToken(SyntaxFactory.Token(SyntaxKind.CloseBraceToken) - .WithLeadingTrivia(SyntaxFactory.Whitespace(indentString)))); - - // try { actionBody; } - var tryBlock = SyntaxFactory.Block(tryBodyStatement) - .WithOpenBraceToken(SyntaxFactory.Token(SyntaxKind.OpenBraceToken) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed)) - .WithCloseBraceToken(SyntaxFactory.Token(SyntaxKind.CloseBraceToken) - .WithLeadingTrivia(SyntaxFactory.Whitespace(indentString))); - - var tryStatement = SyntaxFactory.TryStatement() - .WithTryKeyword(SyntaxFactory.Token(SyntaxKind.TryKeyword) - .WithLeadingTrivia(SyntaxFactory.Whitespace(indentString))) - .WithBlock(tryBlock) - .WithCatches(SyntaxFactory.SingletonList(catchClause)) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - - yield return declarationStatement; - yield return tryStatement; - } - - private static SyntaxNode ConvertTheoryData(Compilation compilation, SyntaxNode root) + protected override CompilationUnitSyntax ApplyFrameworkSpecificConversions( + CompilationUnitSyntax compilationUnit, + SemanticModel semanticModel, + Compilation compilation) { - var currentRoot = root; - foreach (var objectCreationExpressionSyntax in currentRoot.DescendantNodes().OfType()) - { - var type = objectCreationExpressionSyntax switch - { - ObjectCreationExpressionSyntax explicitObjectCreationExpressionSyntax => explicitObjectCreationExpressionSyntax.Type, - ImplicitObjectCreationExpressionSyntax implicitObjectCreationExpressionSyntax => GetTypeFromImplicitCreation(compilation, implicitObjectCreationExpressionSyntax), - _ => null - }; - - while (type is QualifiedNameSyntax qualifiedNameSyntax) - { - type = qualifiedNameSyntax.Right; - } - - if (type is not GenericNameSyntax genericNameSyntax || - genericNameSyntax.Identifier.Text != "TheoryData") - { - continue; - } - - var originalInitializer = objectCreationExpressionSyntax.Initializer!; - var collectionExpressions = originalInitializer.Expressions; - var elementType = genericNameSyntax.TypeArgumentList.Arguments[0]; - - var arrayType = SyntaxFactory.ArrayType(elementType, - SyntaxFactory.SingletonList( - SyntaxFactory.ArrayRankSpecifier( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.OmittedArraySizeExpression() - ) - ) - )) - .WithoutTrailingTrivia(); - - var newKeyword = SyntaxFactory.Token(SyntaxKind.NewKeyword) - .WithLeadingTrivia(objectCreationExpressionSyntax.GetLeadingTrivia()) - .WithTrailingTrivia(SyntaxFactory.Space); - - var openBrace = originalInitializer.OpenBraceToken; - if (!openBrace.LeadingTrivia.Any(t => t.IsKind(SyntaxKind.EndOfLineTrivia))) - { - openBrace = openBrace.WithLeadingTrivia( - SyntaxFactory.CarriageReturnLineFeed, - SyntaxFactory.Whitespace(" ")); - } - - var newInitializer = SyntaxFactory.InitializerExpression( - SyntaxKind.ArrayInitializerExpression, - openBrace, - collectionExpressions, - originalInitializer.CloseBraceToken); - - var arrayCreationExpressionSyntax = SyntaxFactory.ArrayCreationExpression( - newKeyword, - arrayType, - newInitializer - ) - .WithTrailingTrivia(objectCreationExpressionSyntax.GetTrailingTrivia()); - - currentRoot = currentRoot.ReplaceNode(objectCreationExpressionSyntax, arrayCreationExpressionSyntax); - } - - foreach (var genericTheoryDataTypeSyntax in currentRoot.DescendantNodes().OfType().Where(x => x.Identifier.Text == "TheoryData")) - { - var enumerableTypeSyntax = SyntaxFactory.GenericName( - SyntaxFactory.Identifier("IEnumerable"), - SyntaxFactory.TypeArgumentList(SyntaxFactory.SeparatedList(genericTheoryDataTypeSyntax.TypeArgumentList.Arguments))) - .WithLeadingTrivia(genericTheoryDataTypeSyntax.GetLeadingTrivia()) - .WithTrailingTrivia(genericTheoryDataTypeSyntax.GetTrailingTrivia()); - - currentRoot = currentRoot.ReplaceNode(genericTheoryDataTypeSyntax, enumerableTypeSyntax); - } - - return currentRoot; + // Not used - two-phase architecture handles all conversions + return compilationUnit; } /// - /// Safely gets the type from an implicit object creation expression using semantic analysis. - /// Returns null if semantic analysis fails (defensive for multi-TFM scenarios). + /// Pass-through rewriter that makes no changes. + /// Used to satisfy abstract method contracts when two-phase architecture is active. /// - private static TypeSyntax? GetTypeFromImplicitCreation(Compilation compilation, ImplicitObjectCreationExpressionSyntax implicitCreation) - { - try - { - var semanticModel = compilation.GetSemanticModel(implicitCreation.SyntaxTree); - var typeInfo = semanticModel.GetTypeInfo(implicitCreation); - - if (typeInfo.Type is null || typeInfo.Type.TypeKind == TypeKind.Error) - { - return null; - } - - return SyntaxFactory.ParseTypeName(typeInfo.Type.ToDisplayString()); - } - catch (InvalidOperationException) - { - // Semantic analysis failed due to invalid compilation state - return null; - } - } - - private static SyntaxNode UpdateInitializeDispose(Compilation compilation, SyntaxNode root) - { - // Always operate on the latest root - var currentRoot = root; - foreach (var classDeclaration in root.DescendantNodes().OfType().ToList()) - { - if (classDeclaration.BaseList is null) - { - continue; - } - - var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree); - - if (semanticModel.GetDeclaredSymbol(classDeclaration) is not { } symbol) - { - continue; - } - - // Always get the latest node from the current root - var currentClass = currentRoot.DescendantNodes().OfType() - .FirstOrDefault(n => n.SpanStart == classDeclaration.SpanStart && n.Identifier.Text == classDeclaration.Identifier.Text); - - if (currentClass == null) - { - continue; - } - - var newNode = new InitializeDisposeRewriter(symbol).VisitClassDeclaration(currentClass); - - if (!ReferenceEquals(currentClass, newNode)) - { - currentRoot = currentRoot.ReplaceNode(currentClass, newNode); - } - } - - return currentRoot; - } - - private static SyntaxNode RemoveInterfacesAndBaseClasses(Compilation compilation, SyntaxNode root) - { - var currentRoot = root; - foreach (var classDeclaration in root.DescendantNodes().OfType().ToList()) - { - if (classDeclaration.BaseList is null) - { - continue; - } - - var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree); - - if (semanticModel.GetDeclaredSymbol(classDeclaration) is not { } symbol) - { - continue; - } - - // Always get the latest node from the current root - var currentClass = currentRoot.DescendantNodes().OfType() - .FirstOrDefault(n => n.SpanStart == classDeclaration.SpanStart && n.Identifier.Text == classDeclaration.Identifier.Text); - - if (currentClass == null) - { - continue; - } - - var newNode = new BaseTypeRewriter(symbol).VisitClassDeclaration(currentClass); - - if (!ReferenceEquals(currentClass, newNode)) - { - currentRoot = currentRoot.ReplaceNode(currentClass, newNode); - } - } - - return currentRoot; - } - - private static SyntaxNode RemoveUsingDirectives(SyntaxNode updatedRoot) - { - var compilationUnit = updatedRoot.DescendantNodesAndSelf() - .OfType() - .FirstOrDefault(); - - if (compilationUnit is null) - { - return updatedRoot; - } - - return compilationUnit.WithUsings( - SyntaxFactory.List( - compilationUnit.Usings - .Where(x => x.Name?.ToString().StartsWith("Xunit") is false) - ) - ); - } - - private static SyntaxNode UpdateClassAttributes(Compilation compilation, SyntaxNode root) - { - var rewriter = new XUnitAttributeRewriterInternal(compilation); - return rewriter.Visit(root); - } - - private static string GetSimpleName(AttributeSyntax attributeSyntax) - { - var name = attributeSyntax.Name; - - while (name is not SimpleNameSyntax) - { - name = (name as QualifiedNameSyntax)?.Right; - } - - return name.ToString(); - } - - private static AttributeArgumentListSyntax CreateArgumentListWithAddedArgument( - AttributeArgumentListSyntax existingList, - AttributeArgumentSyntax newArgument) - { - if (existingList.Arguments.Count == 0) - { - return existingList.AddArguments(newArgument); - } - - // Preserve separator trivia by creating a new list with proper separators - var newArguments = new List(existingList.Arguments); - newArguments.Add(newArgument); - - var separators = new List(existingList.Arguments.GetSeparators()); - // Add a comma with trailing space for the new argument - separators.Add(SyntaxFactory.Token(SyntaxKind.CommaToken).WithTrailingTrivia(SyntaxFactory.Space)); - - return SyntaxFactory.AttributeArgumentList( - SyntaxFactory.SeparatedList(newArguments, separators)); - } - - private static IEnumerable ConvertTestAttribute(AttributeSyntax attributeSyntax) - { - yield return SyntaxFactory.Attribute(SyntaxFactory.IdentifierName("Test")); - - if (attributeSyntax.ArgumentList?.Arguments.FirstOrDefault(x => x.NameEquals?.Name.Identifier.ValueText == "Skip") is { } skip) - { - yield return SyntaxFactory.Attribute(SyntaxFactory.IdentifierName("Skip")) - .AddArgumentListArguments(SyntaxFactory.AttributeArgument(skip.Expression)); - } - } - - private static IEnumerable ConvertCollection(Compilation compilation, AttributeSyntax attributeSyntax) - { - var collectionDefinition = GetCollectionAttribute(compilation, attributeSyntax); - - if (collectionDefinition is null) - { - return [attributeSyntax]; - } - - var disableParallelism = - collectionDefinition.ArgumentList?.Arguments.Any(x => x.NameEquals?.Name.Identifier.Text == "DisableParallelization" - && x.Expression is LiteralExpressionSyntax { Token.ValueText: "true" }) ?? false; - - var attributes = new List(); - - if (disableParallelism) - { - attributes.Add(SyntaxFactory.Attribute(SyntaxFactory.ParseName("NotInParallel"))); - } - - var baseListSyntax = collectionDefinition.Parent?.Parent?.ChildNodes().OfType().FirstOrDefault(); - - if (baseListSyntax is null) - { - return attributes; - } - - var collectionFixture = baseListSyntax.Types.Select(x => x.Type).OfType().FirstOrDefault(x => x.Identifier.Text == "ICollectionFixture"); - - if (collectionFixture is null) - { - return attributes; - } - - var type = collectionFixture.TypeArgumentList.Arguments.FirstOrDefault(); - - if (type is null) - { - return attributes; - } - - attributes.Add(SyntaxFactory.Attribute( - SyntaxFactory.GenericName(SyntaxFactory.Identifier("ClassDataSource"), - SyntaxFactory.TypeArgumentList(SyntaxFactory.SingletonSeparatedList(type))), - SyntaxFactory.AttributeArgumentList() - .AddArguments( - SyntaxFactory.AttributeArgument( - nameEquals: SyntaxFactory.NameEquals("Shared"), - nameColon: null, - expression: SyntaxFactory.ParseExpression("SharedType.Keyed") - ), - SyntaxFactory.AttributeArgument( - nameEquals: SyntaxFactory.NameEquals("Key"), - nameColon: null, - expression: GetMethodArgumentName(attributeSyntax) - ) - ).NormalizeWhitespace() - )); - - return attributes; - } - - private static ExpressionSyntax GetMethodArgumentName(AttributeSyntax attributeSyntax) - { - var firstToken = attributeSyntax.ArgumentList?.Arguments.FirstOrDefault()?.GetFirstToken(); - - if (!firstToken.HasValue) - { - return SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal("")); - } - - return SyntaxFactory.ParseExpression(firstToken.Value.Text); - } - - private static AttributeSyntax? GetCollectionAttribute(Compilation compilation, AttributeSyntax attributeSyntax) - { - var firstToken = attributeSyntax.ArgumentList?.Arguments.FirstOrDefault()?.GetFirstToken(); - - if (!firstToken.HasValue) - { - return null; - } - - var collectionName = firstToken.Value.IsKind(SyntaxKind.NameOfKeyword) - ? GetNameFromNameOfToken(firstToken.Value) - : firstToken.Value.ValueText; - - if (collectionName is null) - { - return null; - } - - return compilation.SyntaxTrees - .Select(x => x.GetRoot()) - .SelectMany(x => x.DescendantNodes().OfType()) - .Where(attr => attr.Name.ToString() == "CollectionDefinition") - .FirstOrDefault(x => - { - var syntaxToken = x.ArgumentList?.Arguments.FirstOrDefault()?.GetFirstToken(); - - if (!syntaxToken.HasValue) - { - return false; - } - - var name = syntaxToken.Value.IsKind(SyntaxKind.NameOfKeyword) - ? GetNameFromNameOfToken(syntaxToken.Value) - : syntaxToken.Value.ValueText; - - return name == collectionName; - }); - } - - private static string? GetNameFromNameOfToken(SyntaxToken token) + private class PassThroughRewriter : CSharpSyntaxRewriter { - var expression = SyntaxFactory.ParseExpression(token.Text) as InvocationExpressionSyntax; - - if (expression?.Expression is IdentifierNameSyntax { Identifier.Text: "nameof" } && - expression.ArgumentList.Arguments.FirstOrDefault()?.Expression is IdentifierNameSyntax nameOfArgument) - { - return nameOfArgument.Identifier.Text; - } - - return null; } - // This is the AttributeRewriter for the base class pattern - private class XUnitAttributeRewriter : AttributeRewriter + /// + /// Pass-through attribute rewriter that makes no changes. + /// Used to satisfy abstract method contracts when two-phase architecture is active. + /// + private class PassThroughAttributeRewriter : AttributeRewriter { protected override string FrameworkName => "XUnit"; protected override bool IsFrameworkAttribute(string attributeName) { - return attributeName is "Fact" or "FactAttribute" or "Theory" or "TheoryAttribute" - or "Trait" or "TraitAttribute" or "InlineData" or "InlineDataAttribute" - or "MemberData" or "MemberDataAttribute" or "ClassData" or "ClassDataAttribute" - or "Collection" or "CollectionAttribute" or "CollectionDefinition" or "CollectionDefinitionAttribute"; + return false; } - protected override AttributeArgumentListSyntax? ConvertAttributeArguments(AttributeArgumentListSyntax argumentList, string attributeName) + protected override AttributeArgumentListSyntax? ConvertAttributeArguments( + AttributeArgumentListSyntax argumentList, + string attributeName) { - // XUnit attributes don't need special argument conversion - handled by XUnitAttributeRewriterInternal return argumentList; } } - - private class XUnitAssertionRewriter : AssertionRewriter - { - protected override string FrameworkName => "XUnit"; - - public XUnitAssertionRewriter(SemanticModel semanticModel) : base(semanticModel) - { - } - - protected override bool IsFrameworkAssertionNamespace(string namespaceName) - { - return namespaceName.Equals("Xunit", StringComparison.OrdinalIgnoreCase) || - namespaceName.StartsWith("Xunit.", StringComparison.OrdinalIgnoreCase); - } - - protected override bool IsKnownAssertionTypeBySyntax(string targetType, string methodName) - { - // XUnit assertion type that can be detected by syntax - return targetType == "Assert"; - } - - protected override ExpressionSyntax? ConvertAssertionIfNeeded(InvocationExpressionSyntax invocation) - { - if (!IsFrameworkAssertion(invocation)) - { - return null; - } - - // Handle both simple (Assert.Equal) and qualified (Xunit.Assert.Equal) names - if (invocation.Expression is MemberAccessExpressionSyntax memberAccess) - { - var typeName = GetSimpleTypeName(memberAccess.Expression); - if (typeName == "Assert") - { - return ConvertXUnitAssertion(invocation, memberAccess.Name.Identifier.Text, memberAccess.Name); - } - } - - return null; - } - - /// - /// Extracts the simple type name from an expression. - /// Handles both simple identifiers and qualified names like "Xunit.Assert". - /// - private static string GetSimpleTypeName(ExpressionSyntax expression) - { - return expression switch - { - IdentifierNameSyntax identifier => identifier.Identifier.Text, - MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.Text, - _ => expression.ToString() - }; - } - - private ExpressionSyntax? ConvertXUnitAssertion(InvocationExpressionSyntax invocation, string methodName, SimpleNameSyntax nameNode) - { - var arguments = invocation.ArgumentList.Arguments; - - return methodName switch - { - // Equality assertions - check for comparer overloads - "Equal" when arguments.Count >= 3 && IsLikelyComparerArgument(arguments[2]) => - CreateEqualWithComparerComment(arguments), - "Equal" when arguments.Count >= 2 => - CreateTUnitAssertion("IsEqualTo", arguments[1].Expression, arguments[0]), - "NotEqual" when arguments.Count >= 3 && IsLikelyComparerArgument(arguments[2]) => - CreateNotEqualWithComparerComment(arguments), - "NotEqual" when arguments.Count >= 2 => - CreateTUnitAssertion("IsNotEqualTo", arguments[1].Expression, arguments[0]), - - // Boolean assertions - "True" when arguments.Count >= 2 => - CreateTUnitAssertionWithMessage("IsTrue", arguments[0].Expression, arguments[1].Expression), - "True" when arguments.Count >= 1 => - CreateTUnitAssertion("IsTrue", arguments[0].Expression), - "False" when arguments.Count >= 2 => - CreateTUnitAssertionWithMessage("IsFalse", arguments[0].Expression, arguments[1].Expression), - "False" when arguments.Count >= 1 => - CreateTUnitAssertion("IsFalse", arguments[0].Expression), - - // Null assertions - "Null" when arguments.Count >= 1 => - CreateTUnitAssertion("IsNull", arguments[0].Expression), - "NotNull" when arguments.Count >= 1 => - CreateTUnitAssertion("IsNotNull", arguments[0].Expression), - - // Reference assertions - "Same" when arguments.Count >= 2 => - CreateTUnitAssertion("IsSameReferenceAs", arguments[1].Expression, arguments[0]), - "NotSame" when arguments.Count >= 2 => - CreateTUnitAssertion("IsNotSameReferenceAs", arguments[1].Expression, arguments[0]), - - // String/Collection contains - use collection assertion for proper overload resolution - "Contains" when arguments.Count >= 2 => - CreateTUnitCollectionAssertion("Contains", arguments[1].Expression, arguments[0]), - "DoesNotContain" when arguments.Count >= 2 => - CreateTUnitCollectionAssertion("DoesNotContain", arguments[1].Expression, arguments[0]), - "StartsWith" when arguments.Count >= 2 => - CreateTUnitAssertion("StartsWith", arguments[1].Expression, arguments[0]), - "EndsWith" when arguments.Count >= 2 => - CreateTUnitAssertion("EndsWith", arguments[1].Expression, arguments[0]), - - // Empty/Not empty - use collection assertion for proper overload resolution - "Empty" when arguments.Count >= 1 => - CreateTUnitCollectionAssertion("IsEmpty", arguments[0].Expression), - "NotEmpty" when arguments.Count >= 1 => - CreateTUnitCollectionAssertion("IsNotEmpty", arguments[0].Expression), - - // Exception assertions - "Throws" => ConvertThrows(invocation, nameNode), - "ThrowsAsync" => ConvertThrowsAsync(invocation, nameNode), - "ThrowsAny" => ConvertThrowsAny(invocation, nameNode), - "ThrowsAnyAsync" => ConvertThrowsAnyAsync(invocation, nameNode), - - // Type assertions - "IsType" => ConvertIsType(invocation, nameNode), - "IsNotType" => ConvertIsNotType(invocation, nameNode), - "IsAssignableFrom" => ConvertIsAssignableFrom(invocation, nameNode), - - // Range assertions - "InRange" when arguments.Count >= 3 => - CreateTUnitAssertion("IsInRange", arguments[0].Expression, arguments[1], arguments[2]), - "NotInRange" when arguments.Count >= 3 => - CreateTUnitAssertion("IsNotInRange", arguments[0].Expression, arguments[1], arguments[2]), - - // Collection assertions - use collection assertion for proper overload resolution - "Single" when arguments.Count >= 1 => - CreateTUnitCollectionAssertion("HasSingleItem", arguments[0].Expression), - "All" when arguments.Count >= 2 => - CreateAllAssertion(arguments[0].Expression, arguments[1].Expression), - - // Subset/superset - use collection assertion for proper overload resolution - "Subset" when arguments.Count >= 2 => - CreateTUnitCollectionAssertion("IsSubsetOf", arguments[0].Expression, arguments[1]), - "Superset" when arguments.Count >= 2 => - CreateTUnitCollectionAssertion("IsSupersetOf", arguments[0].Expression, arguments[1]), - "ProperSubset" when arguments.Count >= 2 => - CreateProperSubsetWithTodo(arguments), - "ProperSuperset" when arguments.Count >= 2 => - CreateProperSupersetWithTodo(arguments), - - // Unique items - use collection assertion for proper overload resolution - "Distinct" when arguments.Count >= 1 => - CreateTUnitCollectionAssertion("HasDistinctItems", arguments[0].Expression), - - // Equivalent (order independent) - use collection assertion for proper overload resolution - "Equivalent" when arguments.Count >= 2 => - CreateTUnitCollectionAssertion("IsEquivalentTo", arguments[1].Expression, arguments[0]), - - // Regex assertions - "Matches" when arguments.Count >= 2 => - CreateTUnitAssertion("Matches", arguments[1].Expression, arguments[0]), - "DoesNotMatch" when arguments.Count >= 2 => - CreateTUnitAssertion("DoesNotMatch", arguments[1].Expression, arguments[0]), - - // Collection with inspectors - complex, needs TODO - "Collection" when arguments.Count >= 2 => - CreateCollectionWithTodo(arguments), - - // PropertyChanged - not supported in TUnit - "PropertyChanged" when arguments.Count >= 3 => - CreatePropertyChangedTodo(arguments), - "PropertyChangedAsync" when arguments.Count >= 3 => - CreatePropertyChangedTodo(arguments), - - // Raises events - not supported in TUnit - "Raises" => CreateRaisesTodo(arguments), - "RaisesAsync" => CreateRaisesTodo(arguments), - "RaisesAny" => CreateRaisesTodo(arguments), - "RaisesAnyAsync" => CreateRaisesTodo(arguments), - - _ => null - }; - } - - private ExpressionSyntax CreateEqualWithComparerComment(SeparatedSyntaxList arguments) - { - var result = CreateTUnitAssertion("IsEqualTo", arguments[1].Expression, arguments[0]); - return result.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - custom comparer was used. Consider using Assert.That(...).IsEquivalentTo() or a custom condition."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreateNotEqualWithComparerComment(SeparatedSyntaxList arguments) - { - var result = CreateTUnitAssertion("IsNotEqualTo", arguments[1].Expression, arguments[0]); - return result.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - custom comparer was used. Consider using a custom condition."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreateCollectionWithTodo(SeparatedSyntaxList arguments) - { - // Assert.Collection(collection, inspector1, inspector2, ...) has no direct TUnit equivalent - // Convert to HasCount check and add TODO for manual inspector conversion - var collection = arguments[0].Expression; - var inspectorCount = arguments.Count - 1; - - var result = CreateTUnitCollectionAssertion("HasCount", collection, - SyntaxFactory.Argument( - SyntaxFactory.LiteralExpression( - SyntaxKind.NumericLiteralExpression, - SyntaxFactory.Literal(inspectorCount)))); - - // Just add TODO comment and newline - indentation will be handled by VisitInvocationExpression - return result.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - Assert.Collection had element inspectors. Manually add assertions for each element."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreatePropertyChangedTodo(SeparatedSyntaxList arguments) - { - // Assert.PropertyChanged(object, propertyName, action) - TUnit doesn't have this - // Create a placeholder that executes the action and add TODO - var action = arguments.Count > 2 ? arguments[2].Expression : arguments[0].Expression; - - // Create: action() with TODO comment - var invocation = action is LambdaExpressionSyntax - ? (ExpressionSyntax)SyntaxFactory.InvocationExpression( - SyntaxFactory.ParenthesizedExpression(action)) - : SyntaxFactory.InvocationExpression(action); - - return invocation.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - PropertyChanged assertion not supported. Implement INotifyPropertyChanged testing manually."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreateRaisesTodo(SeparatedSyntaxList arguments) - { - // Assert.Raises(attach, detach, action) - TUnit doesn't have this - // Create placeholder with TODO - var action = arguments.Count > 2 ? arguments[2].Expression : arguments[0].Expression; - - var invocation = action is LambdaExpressionSyntax - ? (ExpressionSyntax)SyntaxFactory.InvocationExpression( - SyntaxFactory.ParenthesizedExpression(action)) - : SyntaxFactory.InvocationExpression(action); - - return invocation.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - Raises assertion not supported. Implement event testing manually."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreateProperSubsetWithTodo(SeparatedSyntaxList arguments) - { - // ProperSubset means strict subset (not equal to superset) - // TUnit's IsSubsetOf doesn't distinguish between proper/improper - var result = CreateTUnitCollectionAssertion("IsSubsetOf", arguments[0].Expression, arguments[1]); - return result.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - ProperSubset requires strict subset (not equal). Add additional assertion if needed."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreateProperSupersetWithTodo(SeparatedSyntaxList arguments) - { - // ProperSuperset means strict superset (not equal to subset) - // TUnit's IsSupersetOf doesn't distinguish between proper/improper - var result = CreateTUnitCollectionAssertion("IsSupersetOf", arguments[0].Expression, arguments[1]); - return result.WithLeadingTrivia( - SyntaxFactory.Comment("// TODO: TUnit migration - ProperSuperset requires strict superset (not equal). Add additional assertion if needed."), - SyntaxFactory.EndOfLine("\n")); - } - - private ExpressionSyntax CreateAllAssertion(ExpressionSyntax collection, ExpressionSyntax actionOrPredicate) - { - // Assert.All(collection, action) -> await Assert.That(collection).All(predicate) - // Try to extract a simple predicate from the action if possible - - var predicateExpression = TryConvertActionToPredicate(actionOrPredicate); - - // Use CreateTUnitCollectionAssertion with the predicate as an argument - return CreateTUnitCollectionAssertion("All", collection, SyntaxFactory.Argument(predicateExpression)); - } - - private ExpressionSyntax TryConvertActionToPredicate(ExpressionSyntax actionExpression) - { - // Try to convert xUnit action patterns to TUnit predicates - // Pattern: item => Assert.True(item > 0) -> item => item > 0 - // Pattern: item => Assert.False(item < 0) -> item => !(item < 0) - // Pattern: item => Assert.NotNull(item) -> item => item != null - // Pattern: item => Assert.Null(item) -> item => item == null - - if (actionExpression is SimpleLambdaExpressionSyntax simpleLambda) - { - var parameter = simpleLambda.Parameter; - var body = simpleLambda.Body; - - // Check if body is an xUnit assertion invocation - if (body is InvocationExpressionSyntax invocation && - invocation.Expression is MemberAccessExpressionSyntax memberAccess && - memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" }) - { - var methodName = memberAccess.Name.Identifier.Text; - var args = invocation.ArgumentList.Arguments; - - ExpressionSyntax? predicateBody = methodName switch - { - "True" when args.Count >= 1 => args[0].Expression, - "False" when args.Count >= 1 => SyntaxFactory.PrefixUnaryExpression( - SyntaxKind.LogicalNotExpression, - SyntaxFactory.ParenthesizedExpression(args[0].Expression)), - "NotNull" when args.Count >= 1 => SyntaxFactory.BinaryExpression( - SyntaxKind.NotEqualsExpression, - args[0].Expression, - SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)), - "Null" when args.Count >= 1 => SyntaxFactory.BinaryExpression( - SyntaxKind.EqualsExpression, - args[0].Expression, - SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)), - _ => null - }; - - if (predicateBody != null) - { - return SyntaxFactory.SimpleLambdaExpression(parameter, predicateBody) - .WithArrowToken(SyntaxFactory.Token(SyntaxKind.EqualsGreaterThanToken) - .WithTrailingTrivia(SyntaxFactory.Space)); - } - } - } - else if (actionExpression is ParenthesizedLambdaExpressionSyntax parenLambda) - { - // Handle (item) => Assert.True(expr) pattern - if (parenLambda.ParameterList.Parameters.Count == 1) - { - var parameter = parenLambda.ParameterList.Parameters[0]; - var body = parenLambda.Body; - - if (body is InvocationExpressionSyntax invocation && - invocation.Expression is MemberAccessExpressionSyntax memberAccess && - memberAccess.Expression is IdentifierNameSyntax { Identifier.Text: "Assert" }) - { - var methodName = memberAccess.Name.Identifier.Text; - var args = invocation.ArgumentList.Arguments; - - ExpressionSyntax? predicateBody = methodName switch - { - "True" when args.Count >= 1 => args[0].Expression, - "False" when args.Count >= 1 => SyntaxFactory.PrefixUnaryExpression( - SyntaxKind.LogicalNotExpression, - SyntaxFactory.ParenthesizedExpression(args[0].Expression)), - "NotNull" when args.Count >= 1 => SyntaxFactory.BinaryExpression( - SyntaxKind.NotEqualsExpression, - args[0].Expression, - SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)), - "Null" when args.Count >= 1 => SyntaxFactory.BinaryExpression( - SyntaxKind.EqualsExpression, - args[0].Expression, - SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)), - _ => null - }; - - if (predicateBody != null) - { - // Convert to simple lambda for cleaner output - return SyntaxFactory.SimpleLambdaExpression( - SyntaxFactory.Parameter(parameter.Identifier), - predicateBody) - .WithArrowToken(SyntaxFactory.Token(SyntaxKind.EqualsGreaterThanToken) - .WithTrailingTrivia(SyntaxFactory.Space)); - } - } - } - } - - // Fallback: return the original expression as-is - // This will likely cause a compilation error, prompting manual conversion - return actionExpression; - } - - private ExpressionSyntax ConvertThrowsAny(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // xUnit Assert.ThrowsAny(Action) -> TUnit Assert.Throws(Action) - // Both are synchronous - ThrowsAny accepts derived types, TUnit's Throws does too - if (nameNode is GenericNameSyntax genericName) - { - var exceptionType = genericName.TypeArgumentList.Arguments[0]; - var action = invocation.ArgumentList.Arguments[0].Expression; - - // Keep it synchronous - TUnit's Assert.Throws(Action) accepts derived types - return SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.GenericName("Throws") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(exceptionType) - ) - ) - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(action) - ) - ) - ); - } - - return CreateTUnitAssertion("Throws", invocation.ArgumentList.Arguments[0].Expression); - } - - private ExpressionSyntax ConvertThrowsAnyAsync(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // xUnit Assert.ThrowsAnyAsync(Func) -> await Assert.ThrowsAsync(Func) - // ThrowsAnyAsync accepts derived types, TUnit's ThrowsAsync does too - if (nameNode is GenericNameSyntax genericName) - { - var exceptionType = genericName.TypeArgumentList.Arguments[0]; - var action = invocation.ArgumentList.Arguments[0].Expression; - - var invocationExpression = SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.GenericName("ThrowsAsync") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(exceptionType) - ) - ) - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(action) - ) - ) - ); - - var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword) - .WithTrailingTrivia(SyntaxFactory.Space); - return SyntaxFactory.AwaitExpression(awaitKeyword, invocationExpression); - } - - return CreateTUnitAssertion("ThrowsAsync", invocation.ArgumentList.Arguments[0].Expression); - } - - private ExpressionSyntax ConvertIsNotType(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // Assert.IsNotType(value) -> await Assert.That(value).IsNotTypeOf() - if (nameNode is GenericNameSyntax genericName) - { - var expectedType = genericName.TypeArgumentList.Arguments[0]; - var value = invocation.ArgumentList.Arguments[0].Expression; - - var assertThatInvocation = SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.IdentifierName("That") - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(value) - ) - ) - ); - - var methodAccess = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - assertThatInvocation, - SyntaxFactory.GenericName("IsNotTypeOf") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(expectedType) - ) - ) - ); - - var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList()); - var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword) - .WithTrailingTrivia(SyntaxFactory.Space); - return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation); - } - - return CreateTUnitAssertion("IsNotTypeOf", invocation.ArgumentList.Arguments[0].Expression); - } - - private ExpressionSyntax ConvertThrows(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // xUnit Assert.Throws(Action) -> TUnit Assert.Throws(Action) - // Both are synchronous and return the exception directly - // NO async conversion needed - TUnit has a sync version that matches xUnit's signature - if (nameNode is GenericNameSyntax genericName) - { - var exceptionType = genericName.TypeArgumentList.Arguments[0]; - var action = invocation.ArgumentList.Arguments[0].Expression; - - // Keep it synchronous - TUnit's Assert.Throws(Action) returns TException directly - return SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.GenericName("Throws") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(exceptionType) - ) - ) - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(action) - ) - ) - ); - } - - // Fallback for non-generic Throws - return CreateTUnitAssertion("Throws", invocation.ArgumentList.Arguments[0].Expression); - } - - private ExpressionSyntax ConvertThrowsAsync(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // Assert.ThrowsAsync(asyncAction) -> await Assert.ThrowsAsync(asyncAction) - if (nameNode is GenericNameSyntax genericName) - { - var exceptionType = genericName.TypeArgumentList.Arguments[0]; - var action = invocation.ArgumentList.Arguments[0].Expression; - - var invocationExpression = SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.GenericName("ThrowsAsync") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(exceptionType) - ) - ) - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(action) - ) - ) - ); - - var awaitKeyword2 = SyntaxFactory.Token(SyntaxKind.AwaitKeyword) - .WithTrailingTrivia(SyntaxFactory.Space); - return SyntaxFactory.AwaitExpression(awaitKeyword2, invocationExpression); - } - - return CreateTUnitAssertion("ThrowsAsync", invocation.ArgumentList.Arguments[0].Expression); - } - - private ExpressionSyntax ConvertIsType(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // Assert.IsType(value) -> await Assert.That(value).IsTypeOf() - if (nameNode is GenericNameSyntax genericName) - { - var expectedType = genericName.TypeArgumentList.Arguments[0]; - var value = invocation.ArgumentList.Arguments[0].Expression; - - var assertThatInvocation = SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.IdentifierName("That") - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(value) - ) - ) - ); - - var methodAccess = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - assertThatInvocation, - SyntaxFactory.GenericName("IsTypeOf") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(expectedType) - ) - ) - ); - - var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList()); - var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword) - .WithTrailingTrivia(SyntaxFactory.Space); - return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation); - } - - return CreateTUnitAssertion("IsTypeOf", invocation.ArgumentList.Arguments[0].Expression); - } - - private ExpressionSyntax ConvertIsAssignableFrom(InvocationExpressionSyntax invocation, SimpleNameSyntax nameNode) - { - // Assert.IsAssignableFrom(value) -> await Assert.That(value).IsAssignableTo() - if (nameNode is GenericNameSyntax genericName) - { - var expectedType = genericName.TypeArgumentList.Arguments[0]; - var value = invocation.ArgumentList.Arguments[0].Expression; - - var assertThatInvocation = SyntaxFactory.InvocationExpression( - SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - SyntaxFactory.IdentifierName("Assert"), - SyntaxFactory.IdentifierName("That") - ), - SyntaxFactory.ArgumentList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Argument(value) - ) - ) - ); - - var methodAccess = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - assertThatInvocation, - SyntaxFactory.GenericName("IsAssignableTo") - .WithTypeArgumentList( - SyntaxFactory.TypeArgumentList( - SyntaxFactory.SingletonSeparatedList(expectedType) - ) - ) - ); - - var fullInvocation = SyntaxFactory.InvocationExpression(methodAccess, SyntaxFactory.ArgumentList()); - var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword) - .WithTrailingTrivia(SyntaxFactory.Space); - return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation); - } - - return CreateTUnitAssertion("IsAssignableTo", invocation.ArgumentList.Arguments[0].Expression); - } - } - - // Internal rewriter used by ApplyFrameworkSpecificConversions with compilation access - private class XUnitAttributeRewriterInternal : CSharpSyntaxRewriter - { - private readonly Compilation _compilation; - - public XUnitAttributeRewriterInternal(Compilation compilation) - { - _compilation = compilation; - } - - public override SyntaxNode VisitAttributeList(AttributeListSyntax node) - { - var newAttributes = new List(); - var separators = new List(); - - // Preserve the original separators (commas with their trivia/spacing) - var originalSeparators = node.Attributes.GetSeparators().ToList(); - - for (int i = 0; i < node.Attributes.Count; i++) - { - var attr = node.Attributes[i]; - var name = GetSimpleName(attr); - - var converted = name switch - { - "Fact" or "FactAttribute" or "Theory" or "TheoryAttribute" => ConvertTestAttribute(attr), - "Trait" or "TraitAttribute" => [SyntaxFactory.Attribute(SyntaxFactory.IdentifierName("Property"), attr.ArgumentList)], - "InlineData" or "InlineDataAttribute" => [SyntaxFactory.Attribute(SyntaxFactory.IdentifierName("Arguments"), attr.ArgumentList)], - "MemberData" or "MemberDataAttribute" => [SyntaxFactory.Attribute(SyntaxFactory.IdentifierName("MethodDataSource"), attr.ArgumentList)], - "ClassData" or "ClassDataAttribute" => - [ - SyntaxFactory.Attribute( - SyntaxFactory.IdentifierName("MethodDataSource"), - CreateArgumentListWithAddedArgument(attr.ArgumentList ?? SyntaxFactory.AttributeArgumentList(), - SyntaxFactory.AttributeArgument(SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, - SyntaxFactory.Literal("GetEnumerator"))))) - ], - "Collection" or "CollectionAttribute" => ConvertCollection(_compilation, attr), - "CollectionDefinition" or "CollectionDefinitionAttribute" => [SyntaxFactory.Attribute( - SyntaxFactory.QualifiedName(SyntaxFactory.IdentifierName("System"), - SyntaxFactory.IdentifierName("Obsolete")))], - _ => [attr] - }; - - int attributesBeforeConversion = newAttributes.Count; - newAttributes.AddRange(converted); - int attributesAfterConversion = newAttributes.Count; - - // Add separators for the newly added attributes - // If we added more than one attribute, add separators between them - for (int j = attributesBeforeConversion; j < attributesAfterConversion - 1; j++) - { - // Use the original separator if available, otherwise create one with space - var separator = i < originalSeparators.Count - ? originalSeparators[i] - : SyntaxFactory.Token(SyntaxKind.CommaToken).WithTrailingTrivia(SyntaxFactory.Space); - separators.Add(separator); - } - - // Add the original separator after this group of attributes if it exists - if (i < originalSeparators.Count && attributesAfterConversion > attributesBeforeConversion) - { - separators.Add(originalSeparators[i]); - } - } - - if (node.Attributes.SequenceEqual(newAttributes)) - { - return node; - } - - // Create separated list with preserved separators - return SyntaxFactory.AttributeList( - SyntaxFactory.SeparatedList(newAttributes, separators)) - .WithLeadingTrivia(node.GetLeadingTrivia()) - .WithTrailingTrivia(node.GetTrailingTrivia()); - } - } - - private class BaseTypeRewriter(INamedTypeSymbol namedTypeSymbol) : CSharpSyntaxRewriter - { - public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node) - { - if (node.BaseList is null) - { - return node; - } - - INamedTypeSymbol[] types = namedTypeSymbol.BaseType != null && namedTypeSymbol.BaseType.SpecialType != SpecialType.System_Object - ? [namedTypeSymbol.BaseType, .. namedTypeSymbol.AllInterfaces] - : [.. namedTypeSymbol.AllInterfaces]; - - var classFixturesToConvert = types - .Where(x => x.Name == "IClassFixture" && x.ContainingNamespace.Name.StartsWith("Xunit")) - .Select(x => SyntaxFactory.Attribute( - SyntaxFactory.GenericName(SyntaxFactory.ParseToken("ClassDataSource"), SyntaxFactory.TypeArgumentList(SyntaxFactory.SingletonSeparatedList(SyntaxFactory.ParseTypeName(x.TypeArguments.First().ToDisplayString())))).WithoutTrailingTrivia(), - SyntaxFactory.AttributeArgumentList() - .AddArguments( - SyntaxFactory.AttributeArgument( - nameEquals: SyntaxFactory.NameEquals("Shared"), - nameColon: null, - expression: SyntaxFactory.ParseExpression("SharedType.PerClass") - ) - ) - ).WithLeadingTrivia(SyntaxFactory.ElasticMarker)) - .ToList(); - - if (classFixturesToConvert.Count > 0) - { - node = node.AddAttributeLists(SyntaxFactory.AttributeList(SyntaxFactory.SeparatedList(classFixturesToConvert))); - } - - var newBaseList = types.Where(x => !x.ContainingNamespace.Name.StartsWith("Xunit")) - .Select(x => SyntaxFactory.SimpleBaseType(SyntaxFactory.ParseTypeName(x.ToDisplayString()))) - .ToList(); - - if (newBaseList.Count == 0) - { - // When removing the entire base list, preserve the trailing trivia - // The base list's trailing trivia typically contains the newline before the opening brace - var baseListTrailingTrivia = node.BaseList.GetTrailingTrivia(); - - // Apply the trivia to the element before the base list (parameter list or identifier) - // REPLACE the trailing trivia rather than adding to it to avoid extra spaces - if (node.ParameterList != null) - { - node = node.WithParameterList( - node.ParameterList.WithTrailingTrivia(baseListTrailingTrivia)); - } - else - { - node = node.WithIdentifier( - node.Identifier.WithTrailingTrivia(baseListTrailingTrivia)); - } - - return node.WithBaseList(null); - } - - var baseListSyntax = node.BaseList!.WithTypes(SyntaxFactory.SeparatedList(newBaseList)); - - return node.WithBaseList(baseListSyntax); - } - } - - private class InitializeDisposeRewriter(INamedTypeSymbol namedTypeSymbol) : CSharpSyntaxRewriter - { - public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node) - { - if (node.BaseList is null) - { - return node; - } - - var interfaces = namedTypeSymbol.Interfaces - .Where(x => x.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) is "global::Xunit.IAsyncLifetime" or "global::System.IAsyncDisposable" or "global::System.IDisposable") - .ToArray(); - - if (interfaces.Length == 0) - { - return node; - } - - var hasAsyncLifetime = interfaces.Any(x => x.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) == "global::Xunit.IAsyncLifetime"); - var hasAsyncDisposable = interfaces.Any(x => x.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) == "global::System.IAsyncDisposable"); - var hasDisposable = interfaces.Any(x => x.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) == "global::System.IDisposable"); - - var isTestClass = namedTypeSymbol - .GetMembers() - .OfType() - .Any(m => m.GetAttributes() - .Any(x => x.AttributeClass?.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix) is "global::Xunit.FactAttribute" - or "global::Xunit.TheoryAttribute") - ); - - if (isTestClass) - { - // Collect all replacements first, then apply them together - var replacements = new Dictionary(); - - if (hasAsyncLifetime && GetInitializeMethod(node) is { } initializeMethod) - { - var attributeList = SyntaxFactory.AttributeList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Attribute(SyntaxFactory.ParseName("Before"), SyntaxFactory.ParseAttributeArgumentList("(Test)")))); - - var newMethod = initializeMethod - .WithReturnType(SyntaxFactory.ParseTypeName("Task").WithTrailingTrivia(SyntaxFactory.Space)) - .WithAttributeLists(SyntaxFactory.SingletonList(attributeList)); - - replacements[initializeMethod] = newMethod; - } - - if ((hasAsyncLifetime || hasAsyncDisposable) && GetDisposeAsyncMethod(node) is { } disposeAsyncMethod) - { - var attributeList = SyntaxFactory.AttributeList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Attribute(SyntaxFactory.ParseName("After"), SyntaxFactory.ParseAttributeArgumentList("(Test)")))); - - var newMethod = disposeAsyncMethod - .WithReturnType(SyntaxFactory.ParseTypeName("Task").WithTrailingTrivia(SyntaxFactory.Space)) - .WithAttributeLists(SyntaxFactory.SingletonList(attributeList)); - - replacements[disposeAsyncMethod] = newMethod; - } - - if (hasDisposable && GetDisposeMethod(node) is { } disposeMethod) - { - var attributeList = SyntaxFactory.AttributeList( - SyntaxFactory.SingletonSeparatedList( - SyntaxFactory.Attribute(SyntaxFactory.ParseName("After"), SyntaxFactory.ParseAttributeArgumentList("(Test)")))); - - var newMethod = disposeMethod - .WithAttributeLists(SyntaxFactory.SingletonList(attributeList)); - - replacements[disposeMethod] = newMethod; - } - - // Apply all replacements at once - if (replacements.Count > 0) - { - node = node.ReplaceNodes(replacements.Keys, (oldNode, _) => replacements[oldNode]); - } - - // Reorder methods: Test methods should come first, then Before/After methods - var testMethods = node.Members - .OfType() - .Where(m => m.AttributeLists.Any(al => al.Attributes.Any(a => - a.Name.ToString() is "Test" or "Fact" or "Theory"))) - .ToList(); - - var beforeAfterMethods = node.Members - .OfType() - .Where(m => m.AttributeLists.Any(al => al.Attributes.Any(a => - a.Name.ToString() is "Before" or "After"))) - .ToList(); - - var otherMembers = node.Members - .Except(testMethods.Cast()) - .Except(beforeAfterMethods.Cast()) - .ToList(); - - // If we have both test methods and before/after methods, reorder - if (testMethods.Count > 0 && beforeAfterMethods.Count > 0) - { - // Normalize trivia: all members should have blank line before them (except first) - var allMethodsToReorder = new List(); - allMethodsToReorder.AddRange(testMethods); - allMethodsToReorder.AddRange(beforeAfterMethods); - - var normalizedMethods = allMethodsToReorder.Select((m, i) => - { - // Strip existing leading and trailing trivia, then set normalized trivia - var strippedMethod = m.WithLeadingTrivia().WithTrailingTrivia(); - - // Check if method has attributes - var hasAttributes = strippedMethod.AttributeLists.Count > 0; - - if (hasAttributes) - { - // For methods with attributes, we need to: - // 1. Set trivia on the attribute's first token - // 2. Strip the attribute list's trailing trivia - // 3. Set trivia on the first modifier token - - var firstToken = strippedMethod.GetFirstToken(); // This is the '[' token - - if (i == 0) - { - // First method: just indentation on attribute - strippedMethod = strippedMethod.ReplaceToken( - firstToken, - firstToken.WithLeadingTrivia(SyntaxFactory.Whitespace(" "))); - } - else - { - // Subsequent methods: blank line + indentation on attribute - strippedMethod = strippedMethod.ReplaceToken( - firstToken, - firstToken.WithLeadingTrivia( - SyntaxFactory.CarriageReturnLineFeed, - SyntaxFactory.Whitespace(" "))); - } - - // Strip trailing trivia from all attribute lists to prevent extra newlines - var attributeListsWithoutTrailing = strippedMethod.AttributeLists - .Select(al => al.WithTrailingTrivia()) - .ToList(); - strippedMethod = strippedMethod.WithAttributeLists( - SyntaxFactory.List(attributeListsWithoutTrailing)); - - // Now get the first modifier AFTER the replacements - var firstModifier = strippedMethod.Modifiers.FirstOrDefault(); - - // Add newline + indentation before the modifier (public, etc.) - if (firstModifier != default) - { - strippedMethod = strippedMethod.ReplaceToken( - firstModifier, - firstModifier.WithLeadingTrivia( - SyntaxFactory.CarriageReturnLineFeed, - SyntaxFactory.Whitespace(" "))); - } - - return strippedMethod.WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - } - else - { - // No attributes: set trivia on first token (modifier or return type) - if (i == 0) - { - // First method: just indentation - return strippedMethod - .WithLeadingTrivia(SyntaxFactory.Whitespace(" ")) - .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - } - else - { - // Subsequent methods: blank line + indentation - return strippedMethod.WithLeadingTrivia( - SyntaxFactory.CarriageReturnLineFeed, - SyntaxFactory.Whitespace(" ") - ).WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed); - } - } - }).ToList(); - - // New order: other members, then all normalized methods - var reorderedMembers = new List(); - reorderedMembers.AddRange(otherMembers); - reorderedMembers.AddRange(normalizedMethods); - - node = node.WithMembers(SyntaxFactory.List(reorderedMembers)); - } - - // Update base list to remove interfaces - var interfacesToRemove = new[] { "IAsyncLifetime", "IAsyncDisposable", "IDisposable" }; - var newBaseTypes = node.BaseList!.Types - .Where(x => !interfacesToRemove.Any(i => x.Type.TryGetInferredMemberName()?.EndsWith(i) == true)) - .ToList(); - - if (newBaseTypes.Count != node.BaseList.Types.Count) - { - if (newBaseTypes.Count == 0) - { - // When removing the entire base list, preserve the trailing trivia - // The base list's trailing trivia typically contains the newline before the opening brace - var baseListTrailingTrivia = node.BaseList.GetTrailingTrivia(); - - // Apply the trivia to the element before the base list (parameter list or identifier) - // REPLACE the trailing trivia rather than adding to it to avoid extra spaces - if (node.ParameterList != null) - { - node = node.WithParameterList( - node.ParameterList.WithTrailingTrivia(baseListTrailingTrivia)); - } - else - { - node = node.WithIdentifier( - node.Identifier.WithTrailingTrivia(baseListTrailingTrivia)); - } - - node = node.WithBaseList(null); - } - else - { - node = node.WithBaseList( - SyntaxFactory.BaseList(SyntaxFactory.SeparatedList(newBaseTypes)) - .WithTrailingTrivia(node.BaseList.GetTrailingTrivia())); - } - } - } - else - { - if (hasAsyncLifetime && GetInitializeMethod(node) is { } initializeMethod) - { - node = node - .ReplaceNode(initializeMethod, initializeMethod.WithReturnType(SyntaxFactory.ParseTypeName("Task").WithTrailingTrivia(SyntaxFactory.Space))); - - node = node.WithBaseList(SyntaxFactory.BaseList(SyntaxFactory.SeparatedList( - [ - ..node.BaseList!.Types.Where(x => x.Type.TryGetInferredMemberName()?.EndsWith("IAsyncLifetime") is null or false), - SyntaxFactory.SimpleBaseType(SyntaxFactory.ParseTypeName("IAsyncInitializer")) - ]))) - .WithTrailingTrivia(node.BaseList.GetTrailingTrivia()); - } - - if (hasAsyncLifetime && !hasAsyncDisposable) - { - node = node - .WithBaseList(node.BaseList!.AddTypes(SyntaxFactory.SimpleBaseType(SyntaxFactory.ParseTypeName("IAsyncDisposable")))); - } - } - - return node; - - MethodDeclarationSyntax? GetInitializeMethod(ClassDeclarationSyntax classDeclaration) - { - return classDeclaration.Members - .OfType() - .FirstOrDefault(m => m.Identifier.Text == "InitializeAsync"); - } - - MethodDeclarationSyntax? GetDisposeAsyncMethod(ClassDeclarationSyntax classDeclaration) - { - return classDeclaration.Members - .OfType() - .FirstOrDefault(m => m.Identifier.Text == "DisposeAsync"); - } - - MethodDeclarationSyntax? GetDisposeMethod(ClassDeclarationSyntax classDeclaration) - { - return classDeclaration.Members - .OfType() - .FirstOrDefault(m => m.Identifier.Text == "Dispose"); - } - } - } - - private static void UpdateSyntaxTrees(ref Compilation compilation, ref SyntaxTree syntaxTree, ref SyntaxNode updatedRoot) - { - var parseOptions = syntaxTree.Options; - var newSyntaxTree = updatedRoot.SyntaxTree; - - // If the parse options differ, re-parse the updatedRoot with the correct options - if (!Equals(newSyntaxTree.Options, parseOptions)) - { - newSyntaxTree = CSharpSyntaxTree.ParseText( - updatedRoot.ToFullString(), - (CSharpParseOptions) parseOptions, - syntaxTree.FilePath - ); - } - - compilation = compilation.ReplaceSyntaxTree(syntaxTree, newSyntaxTree); - syntaxTree = newSyntaxTree; - - updatedRoot = newSyntaxTree.GetRoot(); - } } diff --git a/TUnit.Analyzers.Tests/MSTestMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/MSTestMigrationAnalyzerTests.cs index 143a1a1822..7d2e4a978e 100644 --- a/TUnit.Analyzers.Tests/MSTestMigrationAnalyzerTests.cs +++ b/TUnit.Analyzers.Tests/MSTestMigrationAnalyzerTests.cs @@ -1741,7 +1741,7 @@ public async Task StringTests() } [Test] - [Ignore("This test is temporarily disabled")] + [Skip("This test is temporarily disabled")] public async Task IgnoredTest() { await Assert.Fail("Should not run"); @@ -1769,6 +1769,161 @@ public async Task BooleanTest(bool value) ); } + [Test] + public async Task MSTest_Method_With_Ref_Parameter_Not_Converted_To_Async() + { + // Test that methods with ref parameters use .Wait() instead of await + await CodeFixer.VerifyCodeFixAsync( + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + {|#0:[TestClass]|} + public class MyClass + { + [TestMethod] + public void MyTest() + { + bool realized = false; + HandleRealized(this, ref realized); + } + + private static void HandleRealized(object sender, ref bool realized) + { + Assert.IsNotNull(sender); + realized = true; + } + } + """, + Verifier.Diagnostic(Rules.MSTestMigration).WithLocation(0), + """ + public class MyClass + { + [Test] + public void MyTest() + { + bool realized = false; + HandleRealized(this, ref realized); + } + + private static void HandleRealized(object sender, ref bool realized) + { + Assert.That(sender).IsNotNull().Wait(); + realized = true; + } + } + """, + ConfigureMSTestTest + ); + } + + [Test] + public async Task MSTest_Method_With_Out_Parameter_Not_Converted_To_Async() + { + await CodeFixer.VerifyCodeFixAsync( + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + {|#0:[TestClass]|} + public class MyClass + { + [TestMethod] + public void MyTest() + { + TryGetValue("key", out int value); + Assert.AreEqual(42, value); + } + + private static void TryGetValue(string key, out int value) + { + Assert.IsNotNull(key); + value = 42; + } + } + """, + Verifier.Diagnostic(Rules.MSTestMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class MyClass + { + [Test] + public async Task MyTest() + { + TryGetValue("key", out int value); + await Assert.That(value).IsEqualTo(42); + } + + private static void TryGetValue(string key, out int value) + { + Assert.That(key).IsNotNull().Wait(); + value = 42; + } + } + """, + ConfigureMSTestTest + ); + } + + [Test] + public async Task MSTest_InterfaceImplementation_NotConvertedToAsync() + { + // Methods that implement interface members should NOT be converted to async + await CodeFixer.VerifyCodeFixAsync( + """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Threading.Tasks; + + public interface ITestRunner + { + void Run(); + } + + {|#0:[TestClass]|} + public class MyClass : ITestRunner + { + [TestMethod] + public void TestMethod() + { + Assert.IsTrue(true); + } + + public void Run() + { + // This implements ITestRunner.Run() and should stay void + var x = 1; + } + } + """, + Verifier.Diagnostic(Rules.MSTestMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public interface ITestRunner + { + void Run(); + } + public class MyClass : ITestRunner + { + [Test] + public async Task TestMethod() + { + await Assert.That(true).IsTrue(); + } + + public void Run() + { + // This implements ITestRunner.Run() and should stay void + var x = 1; + } + } + """, + ConfigureMSTestTest + ); + } + + // NOTE: MSTest lifecycle visibility changes and DoNotParallelize conversion are not implemented + // These features exist in NUnit migration but not MSTest migration + private static void ConfigureMSTestTest(Verifier.Test test) { test.TestState.AdditionalReferences.Add(typeof(TestMethodAttribute).Assembly); diff --git a/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs index 5ee49ee67f..7124e5b1cd 100644 --- a/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs +++ b/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs @@ -1555,8 +1555,8 @@ public void MyTest(int value) public class MyClass { [Test] - [Arguments(1, DisplayName = "Test One", Skip = "Temporarily disabled", Categories = ["Unit"])] - [Arguments(2, DisplayName = "Test Two", Skip = "WIP", Categories = ["Integration"])] + [Arguments(1, DisplayName = "Test One", Categories = ["Unit"], Skip = "Temporarily disabled")] + [Arguments(2, DisplayName = "Test Two", Categories = ["Integration"], Skip = "WIP")] public async Task MyTest(int value) { await Assert.That(value > 0).IsTrue(); @@ -1593,7 +1593,7 @@ public void MyTest(int value) public class MyClass { [Test] - [Arguments(1, DisplayName = "Full featured test", Skip = "Testing migration", Categories = ["Comprehensive"])] + [Arguments(1, DisplayName = "Full featured test", Categories = ["Comprehensive"], Skip = "Testing migration")] [Property("Description", "A complete test case")] [Property("Author", "Developer")] public async Task MyTest(int value) @@ -1811,10 +1811,9 @@ public async Task TestMethod() } [Test] - public async Task NUnit_ThrowsAsync_WithUnrecognizedConstraint_PreservesAction() + public async Task NUnit_ThrowsAsync_WithIsInstanceOf_Converted() { - // Test that unrecognized constraint patterns still preserve the action lambda - // This tests the fallback path in ConvertNUnitThrows + // Test that Is.InstanceOf() constraint is recognized and converted to ThrowsAsync await CodeFixer.VerifyCodeFixAsync( """ using NUnit.Framework; @@ -1825,10 +1824,10 @@ await CodeFixer.VerifyCodeFixAsync( [Test] public void TestMethod() { - // Using Is.InstanceOf which is not recognized by TryExtractTypeFromConstraint + // Is.InstanceOf() is recognized and the type is extracted Assert.ThrowsAsync(Is.InstanceOf(), async () => await SomeMethod()); } - + private async System.Threading.Tasks.Task SomeMethod() { await System.Threading.Tasks.Task.Delay(1); @@ -1846,10 +1845,10 @@ public class MyClass [Test] public async Task TestMethod() { - // Using Is.InstanceOf which is not recognized by TryExtractTypeFromConstraint - await Assert.That(async () => await SomeMethod()).Throws(); + // Is.InstanceOf() is recognized and the type is extracted + await Assert.ThrowsAsync(async () => await SomeMethod()); } - + private async System.Threading.Tasks.Task SomeMethod() { await System.Threading.Tasks.Task.Delay(1); @@ -2961,6 +2960,110 @@ private static void TryGetValue(string key, out int value) ); } + [Test] + public async Task NUnit_Method_With_Ref_Parameter_Multiple_Assertions_Uses_Wait() + { + // Multiple assertions in a method with ref parameters should all use .Wait() + await CodeFixer.VerifyCodeFixAsync( + """ + using NUnit.Framework; + + {|#0:public class MyClass|} + { + [Test] + public void MyTest() + { + int value = 0; + ProcessValue(ref value); + } + + private static void ProcessValue(ref int value) + { + Assert.That(value, Is.EqualTo(0)); + value = 42; + Assert.That(value, Is.EqualTo(42)); + Assert.That(value, Is.GreaterThan(0)); + } + } + """, + Verifier.Diagnostic(Rules.NUnitMigration).WithLocation(0), + """ + + public class MyClass + { + [Test] + public void MyTest() + { + int value = 0; + ProcessValue(ref value); + } + + private static void ProcessValue(ref int value) + { + Assert.That(value).IsEqualTo(0).Wait(); + value = 42; + Assert.That(value).IsEqualTo(42).Wait(); + Assert.That(value).IsGreaterThan(0).Wait(); + } + } + """, + ConfigureNUnitTest + ); + } + + [Test] + public async Task NUnit_AssertMultiple_Inside_Ref_Parameter_Method_Uses_Wait() + { + // Assert.Multiple inside a method with ref parameters - assertions should use .Wait() + await CodeFixer.VerifyCodeFixAsync( + """ + using NUnit.Framework; + + {|#0:public class MyClass|} + { + [Test] + public void MyTest() + { + int value = 42; + ValidateValue(ref value); + } + + private static void ValidateValue(ref int value) + { + Assert.Multiple(() => + { + Assert.That(value, Is.GreaterThan(0)); + Assert.That(value, Is.LessThan(100)); + }); + } + } + """, + Verifier.Diagnostic(Rules.NUnitMigration).WithLocation(0), + """ + + public class MyClass + { + [Test] + public void MyTest() + { + int value = 42; + ValidateValue(ref value); + } + + private static void ValidateValue(ref int value) + { + using (Assert.Multiple()) + { + Assert.That(value).IsGreaterThan(0).Wait(); + Assert.That(value).IsLessThan(100).Wait(); + } + } + } + """, + ConfigureNUnitTest + ); + } + [Test] public async Task NUnit_InterfaceImplementation_NotConvertedToAsync() { diff --git a/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs index 3d616e0ce6..b231a4dad9 100644 --- a/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs +++ b/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs @@ -98,7 +98,8 @@ public void MyTest() public class MyClass { - [Test, Skip("Reason")] + [Test] + [Skip("Reason")] public void MyTest() { } @@ -237,7 +238,8 @@ public class MyCollection : ICollectionFixture public class MyType; - [NotInParallel, ClassDataSource(Shared = SharedType.Keyed, Key = "MyCollection")] + [ClassDataSource(Shared = SharedType.Keyed, Key = "MyCollection")] + [NotInParallel] public class MyClass { [Test] @@ -439,11 +441,6 @@ public void MyTest() public class MyClass { - [Test] - public void MyTest() - { - } - [Before(Test)] public Task InitializeAsync() { @@ -455,6 +452,11 @@ public Task DisposeAsync() { return default; } + + [Test] + public void MyTest() + { + } } """, ConfigureXUnitTest @@ -1325,7 +1327,7 @@ public void ExceptionTest() [Test] public async Task AsyncExceptionTest() { - await await Assert.ThrowsAsync(async () => + await Assert.ThrowsAsync(async () => { await Task.CompletedTask; throw new ArgumentException("test"); @@ -1710,7 +1712,8 @@ public async Task StringTests() await Assert.That(str).Contains("lo Wo"); } - [Test, Skip("This test is temporarily disabled")] + [Test] + [Skip("This test is temporarily disabled")] public async Task SkippedTest() { await Assert.That(false).IsTrue().Because("Should not run"); @@ -1738,6 +1741,361 @@ public async Task BooleanTest(bool value) ); } + [Test] + public async Task XUnit_Method_With_Ref_Parameter_Not_Converted_To_Async() + { + // Test that methods with ref parameters use .Wait() instead of await + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class MyClass + { + [Fact] + public void MyTest() + { + bool realized = false; + HandleRealized(this, ref realized); + } + + private static void HandleRealized(object sender, ref bool realized) + { + Assert.NotNull(sender); + realized = true; + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + + public class MyClass + { + [Test] + public void MyTest() + { + bool realized = false; + HandleRealized(this, ref realized); + } + + private static void HandleRealized(object sender, ref bool realized) + { + Assert.That(sender).IsNotNull().Wait(); + realized = true; + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_Method_With_Out_Parameter_Not_Converted_To_Async() + { + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class MyClass + { + [Fact] + public void MyTest() + { + TryGetValue("key", out int value); + Assert.Equal(42, value); + } + + private static void TryGetValue(string key, out int value) + { + Assert.NotNull(key); + value = 42; + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class MyClass + { + [Test] + public async Task MyTest() + { + TryGetValue("key", out int value); + await Assert.That(value).IsEqualTo(42); + } + + private static void TryGetValue(string key, out int value) + { + Assert.That(key).IsNotNull().Wait(); + value = 42; + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_InterfaceImplementation_NotConvertedToAsync() + { + // Methods that implement interface members should NOT be converted to async + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + using System.Threading.Tasks; + + public interface ITestRunner + { + void Run(); + } + + public class MyClass : ITestRunner + { + [Fact] + public void TestMethod() + { + Assert.True(true); + } + + public void Run() + { + // This implements ITestRunner.Run() and should stay void + var x = 1; + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public interface ITestRunner + { + void Run(); + } + + public class MyClass : ITestRunner + { + [Test] + public async Task TestMethod() + { + await Assert.That(true).IsTrue(); + } + + public void Run() + { + // This implements ITestRunner.Run() and should stay void + var x = 1; + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_Nested_Class_Converted() + { + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class OuterClass + { + public class InnerTests + { + [Fact] + public void InnerTest() + { + Assert.True(true); + } + } + + [Fact] + public void OuterTest() + { + Assert.False(false); + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class OuterClass + { + public class InnerTests + { + [Test] + public async Task InnerTest() + { + await Assert.That(true).IsTrue(); + } + } + + [Test] + public async Task OuterTest() + { + await Assert.That(false).IsFalse(); + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_Multiple_Classes_In_File_All_Converted() + { + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class FirstTestClass + { + [Fact] + public void FirstTest() + { + Assert.True(true); + } + } + + public class SecondTestClass + { + [Fact] + public void SecondTest() + { + Assert.False(false); + } + } + + public class ThirdTestClass + { + [Fact] + public void ThirdTest() + { + Assert.Equal(1, 1); + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class FirstTestClass + { + [Test] + public async Task FirstTest() + { + await Assert.That(true).IsTrue(); + } + } + + public class SecondTestClass + { + [Test] + public async Task SecondTest() + { + await Assert.That(false).IsFalse(); + } + } + + public class ThirdTestClass + { + [Test] + public async Task ThirdTest() + { + await Assert.That(1).IsEqualTo(1); + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_Record_Exception_DoesNotThrow_Pattern() + { + // Record.Exception returning null is equivalent to DoesNotThrow + // The migration converts Record.Exception to a try-catch pattern + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class MyClass + { + [Fact] + public void TestMethod() + { + int x = 1; + int y = 2; + var ex = Record.Exception(() => x += y); + Assert.Null(ex); + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class MyClass + { + [Test] + public async Task TestMethod() + { + int x = 1; + int y = 2; + Exception? ex = null; + try + { + x += y; + } + catch (Exception e) + { + ex = e; + } + await Assert.That(ex).IsNull(); + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_Generic_Test_Class_Converted() + { + await CodeFixer.VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class GenericTestClass + { + [Fact] + public void GenericTest() + { + var instance = default(T); + Assert.Equal(default(T), instance); + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class GenericTestClass + { + [Test] + public async Task GenericTest() + { + var instance = default(T); + await Assert.That(instance).IsEqualTo(default(T)); + } + } + """, + ConfigureXUnitTest + ); + } + private static void ConfigureXUnitTest(Verifier.Test test) { var globalUsings = ("GlobalUsings.cs", SourceText.From("global using Xunit;")); diff --git a/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs b/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs index 66250badca..9e6025db33 100644 --- a/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs +++ b/TUnit.Analyzers/Migrators/Base/MigrationHelpers.cs @@ -180,7 +180,7 @@ public static CompilationUnitSyntax RemoveFrameworkUsings(CompilationUnitSyntax { var namespacesToRemove = framework switch { - "XUnit" => new[] { "Xunit", "Xunit.Abstractions" }, + "XUnit" => new[] { "Xunit", "Xunit.Abstractions", "Xunit.Sdk" }, "NUnit" => new[] { "NUnit.Framework", "NUnit.Framework.Legacy" }, "MSTest" => new[] { "Microsoft.VisualStudio.TestTools.UnitTesting" }, _ => Array.Empty() diff --git a/TUnit.Assertions.Tests/EventAssertionTests.cs b/TUnit.Assertions.Tests/EventAssertionTests.cs new file mode 100644 index 0000000000..be804eaa2e --- /dev/null +++ b/TUnit.Assertions.Tests/EventAssertionTests.cs @@ -0,0 +1,298 @@ +using System.ComponentModel; + +namespace TUnit.Assertions.Tests; + +public class EventAssertionTests +{ + #region PropertyChanged Tests + + [Test] + public void PropertyChanged_Passes_When_Property_Changes() + { + var obj = new NotifyingClass(); + + Assert.PropertyChanged(obj, nameof(NotifyingClass.Name), () => + { + obj.Name = "New Value"; + }); + } + + [Test] + public async Task PropertyChanged_Fails_When_Property_Does_Not_Change() + { + var obj = new NotifyingClass(); + + var exception = await Assert.ThrowsAsync(() => + { + Assert.PropertyChanged(obj, nameof(NotifyingClass.Name), () => + { + // Do nothing - property doesn't change + }); + return Task.CompletedTask; + }); + + await Assert.That(exception.Message).Contains("Name"); + } + + [Test] + public async Task PropertyChanged_Fails_When_Different_Property_Changes() + { + var obj = new NotifyingClass(); + + var exception = await Assert.ThrowsAsync(() => + { + Assert.PropertyChanged(obj, nameof(NotifyingClass.Name), () => + { + obj.Age = 30; // Different property + }); + return Task.CompletedTask; + }); + + await Assert.That(exception.Message).Contains("Name"); + } + + [Test] + public async Task PropertyChangedAsync_Passes_When_Property_Changes() + { + var obj = new NotifyingClass(); + + await Assert.PropertyChangedAsync(obj, nameof(NotifyingClass.Name), async () => + { + await Task.Delay(1); + obj.Name = "New Value"; + }); + } + + [Test] + public async Task PropertyChangedAsync_Fails_When_Property_Does_Not_Change() + { + var obj = new NotifyingClass(); + + await Assert.That(async () => + { + await Assert.PropertyChangedAsync(obj, nameof(NotifyingClass.Name), async () => + { + await Task.Delay(1); + // Do nothing + }); + }).Throws(); + } + + #endregion + + #region Raises Tests + + [Test] + public async Task Raises_Passes_When_Event_Is_Raised() + { + var obj = new EventRaisingClass(); + + var result = Assert.Raises( + handler => obj.CustomEvent += handler, + handler => obj.CustomEvent -= handler, + () => obj.RaiseCustomEvent("test")); + + await Assert.That(result).IsNotNull(); + await Assert.That(result.Arguments.Value).IsEqualTo("test"); + await Assert.That(result.Sender).IsSameReferenceAs(obj); + } + + [Test] + public async Task Raises_Fails_When_Event_Is_Not_Raised() + { + var obj = new EventRaisingClass(); + + var exception = await Assert.ThrowsAsync(() => + { + Assert.Raises( + handler => obj.CustomEvent += handler, + handler => obj.CustomEvent -= handler, + () => + { + // Don't raise the event + }); + return Task.CompletedTask; + }); + + await Assert.That(exception.Message).Contains("CustomEventArgs"); + } + + [Test] + public async Task RaisesAsync_Passes_When_Event_Is_Raised() + { + var obj = new EventRaisingClass(); + + var result = await Assert.RaisesAsync( + handler => obj.CustomEvent += handler, + handler => obj.CustomEvent -= handler, + async () => + { + await Task.Delay(1); + obj.RaiseCustomEvent("async test"); + }); + + await Assert.That(result).IsNotNull(); + await Assert.That(result.Arguments.Value).IsEqualTo("async test"); + } + + [Test] + public async Task RaisesAsync_Fails_When_Event_Is_Not_Raised() + { + var obj = new EventRaisingClass(); + + await Assert.That(async () => + { + await Assert.RaisesAsync( + handler => obj.CustomEvent += handler, + handler => obj.CustomEvent -= handler, + async () => + { + await Task.Delay(1); + // Don't raise the event + }); + }).Throws(); + } + + [Test] + public async Task RaisesAny_Passes_When_Any_Event_Is_Raised() + { + var obj = new EventRaisingClass(); + + var result = Assert.RaisesAny( + handler => obj.GenericEvent += handler, + handler => obj.GenericEvent -= handler, + () => obj.RaiseGenericEvent()); + + await Assert.That(result).IsNotNull(); + await Assert.That(result.Sender).IsSameReferenceAs(obj); + } + + [Test] + public async Task RaisesAny_Passes_For_Derived_EventArgs() + { + var obj = new EventRaisingClass(); + + // RaisesAny should accept derived types + var result = Assert.RaisesAny( + handler => obj.CustomEvent += (s, e) => handler(s, e), + handler => { }, // Can't really unsubscribe this way, but it works for the test + () => obj.RaiseCustomEvent("derived")); + + await Assert.That(result).IsNotNull(); + } + + [Test] + public async Task RaisesAny_Fails_When_No_Event_Is_Raised() + { + var obj = new EventRaisingClass(); + + var exception = await Assert.ThrowsAsync(() => + { + Assert.RaisesAny( + handler => obj.GenericEvent += handler, + handler => obj.GenericEvent -= handler, + () => + { + // Don't raise the event + }); + return Task.CompletedTask; + }); + + await Assert.That(exception.Message).Contains("EventArgs"); + } + + [Test] + public async Task RaisesAnyAsync_Passes_When_Event_Is_Raised() + { + var obj = new EventRaisingClass(); + + var result = await Assert.RaisesAnyAsync( + handler => obj.GenericEvent += handler, + handler => obj.GenericEvent -= handler, + async () => + { + await Task.Delay(1); + obj.RaiseGenericEvent(); + }); + + await Assert.That(result).IsNotNull(); + } + + [Test] + public async Task RaisesAnyAsync_Fails_When_No_Event_Is_Raised() + { + var obj = new EventRaisingClass(); + + await Assert.That(async () => + { + await Assert.RaisesAnyAsync( + handler => obj.GenericEvent += handler, + handler => obj.GenericEvent -= handler, + async () => + { + await Task.Delay(1); + // Don't raise the event + }); + }).Throws(); + } + + #endregion + + #region Helper Classes + + private class NotifyingClass : INotifyPropertyChanged + { + private string? _name; + private int _age; + + public event PropertyChangedEventHandler? PropertyChanged; + + public string? Name + { + get => _name; + set + { + _name = value; + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(nameof(Name))); + } + } + + public int Age + { + get => _age; + set + { + _age = value; + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(nameof(Age))); + } + } + } + + private class CustomEventArgs : EventArgs + { + public string Value { get; } + + public CustomEventArgs(string value) + { + Value = value; + } + } + + private class EventRaisingClass + { + public event EventHandler? CustomEvent; + public event EventHandler? GenericEvent; + + public void RaiseCustomEvent(string value) + { + CustomEvent?.Invoke(this, new CustomEventArgs(value)); + } + + public void RaiseGenericEvent() + { + GenericEvent?.Invoke(this, EventArgs.Empty); + } + } + + #endregion +} diff --git a/TUnit.Assertions.Tests/StrictEqualityTests.cs b/TUnit.Assertions.Tests/StrictEqualityTests.cs new file mode 100644 index 0000000000..bcce44aaf5 --- /dev/null +++ b/TUnit.Assertions.Tests/StrictEqualityTests.cs @@ -0,0 +1,128 @@ +namespace TUnit.Assertions.Tests; + +public class StrictEqualityTests +{ + [Test] + public async Task IsStrictlyEqualTo_Passes_For_Equal_Values() + { + var value = 42; + await Assert.That(value).IsStrictlyEqualTo(42); + } + + [Test] + public async Task IsStrictlyEqualTo_Passes_For_Equal_Strings() + { + var value = "hello"; + await Assert.That(value).IsStrictlyEqualTo("hello"); + } + + [Test] + public async Task IsStrictlyEqualTo_Passes_For_Same_Reference() + { + var obj = new object(); + await Assert.That(obj).IsStrictlyEqualTo(obj); + } + + [Test] + public async Task IsStrictlyEqualTo_Fails_For_Different_Values() + { + var value = 42; + await Assert.That(async () => await Assert.That(value).IsStrictlyEqualTo(43)) + .Throws(); + } + + [Test] + public async Task IsStrictlyEqualTo_Fails_For_Different_References() + { + var obj1 = new NoEqualsOverride(); + var obj2 = new NoEqualsOverride(); + + // Without Equals override, different instances are not equal + await Assert.That(async () => await Assert.That(obj1).IsStrictlyEqualTo(obj2)) + .Throws(); + } + + [Test] + public async Task IsStrictlyEqualTo_Uses_Object_Equals_Not_IEquatable() + { + // CustomEquatable uses IEquatable to say all instances are equal + // But object.Equals is NOT overridden, so strict equality should fail + var obj1 = new CustomEquatable(1); + var obj2 = new CustomEquatable(2); + + // Standard equality would pass (IEquatable says they're equal) + await Assert.That(obj1).IsEqualTo(obj2); + + // Strict equality should fail (object.Equals returns false for different refs) + await Assert.That(async () => await Assert.That(obj1).IsStrictlyEqualTo(obj2)) + .Throws(); + } + + [Test] + public async Task IsNotStrictlyEqualTo_Passes_For_Different_Values() + { + var value = 42; + await Assert.That(value).IsNotStrictlyEqualTo(43); + } + + [Test] + public async Task IsNotStrictlyEqualTo_Passes_For_Different_References() + { + var obj1 = new NoEqualsOverride(); + var obj2 = new NoEqualsOverride(); + await Assert.That(obj1).IsNotStrictlyEqualTo(obj2); + } + + [Test] + public async Task IsNotStrictlyEqualTo_Fails_For_Equal_Values() + { + var value = 42; + await Assert.That(async () => await Assert.That(value).IsNotStrictlyEqualTo(42)) + .Throws(); + } + + [Test] + public async Task IsNotStrictlyEqualTo_Fails_For_Same_Reference() + { + var obj = new object(); + await Assert.That(async () => await Assert.That(obj).IsNotStrictlyEqualTo(obj)) + .Throws(); + } + + [Test] + public async Task IsStrictlyEqualTo_Handles_Null_Values() + { + string? value = null; + await Assert.That(value).IsStrictlyEqualTo(null); + } + + [Test] + public async Task IsStrictlyEqualTo_Fails_When_Only_One_Is_Null() + { + string? value = "hello"; + await Assert.That(async () => await Assert.That(value).IsStrictlyEqualTo(null)) + .Throws(); + } + + // Helper class that does NOT override Equals + private class NoEqualsOverride + { + } + + // Helper class that implements IEquatable but doesn't override object.Equals + private class CustomEquatable : IEquatable + { + public int Value { get; } + + public CustomEquatable(int value) + { + Value = value; + } + + // IEquatable says all instances are equal + public bool Equals(CustomEquatable? other) => true; + + // Deliberately NOT overriding object.Equals + // So object.Equals(a, b) uses reference equality + } +} diff --git a/TUnit.Assertions/Assertions/StrictEqualityAssertions.cs b/TUnit.Assertions/Assertions/StrictEqualityAssertions.cs new file mode 100644 index 0000000000..deff2622db --- /dev/null +++ b/TUnit.Assertions/Assertions/StrictEqualityAssertions.cs @@ -0,0 +1,25 @@ +using TUnit.Assertions.Attributes; + +namespace TUnit.Assertions.Assertions; + +/// +/// Strict equality assertions that use +/// instead of . +/// This is useful when you want to compare without using IEquatable<T> implementations. +/// +file static class StrictEqualityAssertions +{ + /// + /// Asserts that the value is strictly equal to the expected value using . + /// Unlike IsEqualTo, this does not use or custom equality comparers. + /// + [GenerateAssertion(InlineMethodBody = true, ExpectationMessage = "be strictly equal to {expected}")] + public static bool IsStrictlyEqualTo(this T value, T expected) => object.Equals(value, expected); + + /// + /// Asserts that the value is not strictly equal to the expected value using . + /// Unlike IsNotEqualTo, this does not use or custom equality comparers. + /// + [GenerateAssertion(InlineMethodBody = true, ExpectationMessage = "not be strictly equal to {expected}")] + public static bool IsNotStrictlyEqualTo(this T value, T expected) => !object.Equals(value, expected); +} diff --git a/TUnit.Assertions/Extensions/Assert.cs b/TUnit.Assertions/Extensions/Assert.cs index 32e56ece9a..f4c9ba9ff4 100644 --- a/TUnit.Assertions/Extensions/Assert.cs +++ b/TUnit.Assertions/Extensions/Assert.cs @@ -1,4 +1,5 @@ using System.Collections; +using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using TUnit.Assertions.Conditions; @@ -643,4 +644,234 @@ public static void Null( throw new AssertionException($"Expected {expression ?? "value"} to be null, but it was {value.Value}"); } } + + /// + /// Asserts that an object raises the + /// event for the specified property when the action is executed. + /// Example: Assert.PropertyChanged(viewModel, "Name", () => viewModel.Name = "new value"); + /// + /// The object implementing INotifyPropertyChanged to monitor + /// The name of the property expected to change + /// The action that should trigger the PropertyChanged event + /// Thrown if the PropertyChanged event is not raised for the specified property + public static void PropertyChanged( + INotifyPropertyChanged @object, + string propertyName, + Action action) + { + var propertyChanged = false; + + void Handler(object? sender, PropertyChangedEventArgs e) + { + if (e.PropertyName == propertyName) + { + propertyChanged = true; + } + } + + @object.PropertyChanged += Handler; + try + { + action(); + } + finally + { + @object.PropertyChanged -= Handler; + } + + if (!propertyChanged) + { + throw new AssertionException($"Expected PropertyChanged event for property '{propertyName}' but it was not raised"); + } + } + + /// + /// Asserts that an object raises the + /// event for the specified property when the async action is executed. + /// Example: await Assert.PropertyChangedAsync(viewModel, "Name", async () => await viewModel.UpdateNameAsync("new value")); + /// + /// The object implementing INotifyPropertyChanged to monitor + /// The name of the property expected to change + /// The async action that should trigger the PropertyChanged event + /// Thrown if the PropertyChanged event is not raised for the specified property + public static async Task PropertyChangedAsync( + INotifyPropertyChanged @object, + string propertyName, + Func action) + { + var propertyChanged = false; + + void Handler(object? sender, PropertyChangedEventArgs e) + { + if (e.PropertyName == propertyName) + { + propertyChanged = true; + } + } + + @object.PropertyChanged += Handler; + try + { + await action().ConfigureAwait(false); + } + finally + { + @object.PropertyChanged -= Handler; + } + + if (!propertyChanged) + { + throw new AssertionException($"Expected PropertyChanged event for property '{propertyName}' but it was not raised"); + } + } + + /// + /// Asserts that an event is raised when the action is executed and returns information about the raised event. + /// Example: var raised = Assert.Raises<EventArgs>(h => obj.Event += h, h => obj.Event -= h, () => obj.TriggerEvent()); + /// + /// The type of the event arguments + /// Action to attach the event handler + /// Action to detach the event handler + /// The action that should trigger the event + /// A containing the sender and event arguments + /// Thrown if the event is not raised + public static RaisedEvent Raises( + Action> attach, + Action> detach, + Action action) + where T : EventArgs + { + RaisedEvent? result = null; + + void Handler(object? sender, T args) + { + result = new RaisedEvent(sender, args); + } + + attach(Handler); + try + { + action(); + } + finally + { + detach(Handler); + } + + if (result == null) + { + throw new AssertionException($"Expected event of type {typeof(T).Name} to be raised but it was not"); + } + + return result; + } + + /// + /// Asserts that an event is raised when the async action is executed and returns information about the raised event. + /// Example: var raised = await Assert.RaisesAsync<EventArgs>(h => obj.Event += h, h => obj.Event -= h, async () => await obj.TriggerEventAsync()); + /// + /// The type of the event arguments + /// Action to attach the event handler + /// Action to detach the event handler + /// The async action that should trigger the event + /// A containing the sender and event arguments + /// Thrown if the event is not raised + public static async Task> RaisesAsync( + Action> attach, + Action> detach, + Func action) + where T : EventArgs + { + RaisedEvent? result = null; + + void Handler(object? sender, T args) + { + result = new RaisedEvent(sender, args); + } + + attach(Handler); + try + { + await action().ConfigureAwait(false); + } + finally + { + detach(Handler); + } + + if (result == null) + { + throw new AssertionException($"Expected event of type {typeof(T).Name} to be raised but it was not"); + } + + return result; + } + + /// + /// Asserts that an event is raised (with any event args matching the constraint) when the action is executed. + /// Example: var raised = Assert.RaisesAny<EventArgs>(h => obj.Event += h, h => obj.Event -= h, () => obj.TriggerEvent()); + /// + /// The base type of the event arguments (will match any derived type) + /// Action to attach the event handler + /// Action to detach the event handler + /// The action that should trigger the event + /// A containing the sender and event arguments + /// Thrown if the event is not raised + public static RaisedEvent RaisesAny( + Action> attach, + Action> detach, + Action action) + where T : EventArgs + { + // RaisesAny has the same implementation as Raises but the semantic is that it matches any derived type + return Raises(attach, detach, action); + } + + /// + /// Asserts that an event is raised (with any event args matching the constraint) when the async action is executed. + /// Example: var raised = await Assert.RaisesAnyAsync<EventArgs>(h => obj.Event += h, h => obj.Event -= h, async () => await obj.TriggerEventAsync()); + /// + /// The base type of the event arguments (will match any derived type) + /// Action to attach the event handler + /// Action to detach the event handler + /// The async action that should trigger the event + /// A containing the sender and event arguments + /// Thrown if the event is not raised + public static Task> RaisesAnyAsync( + Action> attach, + Action> detach, + Func action) + where T : EventArgs + { + // RaisesAnyAsync has the same implementation as RaisesAsync but the semantic is that it matches any derived type + return RaisesAsync(attach, detach, action); + } +} + +/// +/// Represents the result of a raised event, containing the sender and event arguments. +/// +/// The type of the event arguments +public class RaisedEvent where T : EventArgs +{ + /// + /// Gets the sender of the event. + /// + public object? Sender { get; } + + /// + /// Gets the event arguments. + /// + public T Arguments { get; } + + /// + /// Creates a new instance of . + /// + /// The sender of the event + /// The event arguments + public RaisedEvent(object? sender, T arguments) + { + Sender = sender; + Arguments = arguments; + } } diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt index 90ec203075..e40c19cffa 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet10_0.verified.txt @@ -174,6 +174,16 @@ namespace where T : class { } public static void Null(T? value, [.("value")] string? expression = null) where T : struct { } + public static void PropertyChanged(.INotifyPropertyChanged @object, string propertyName, action) { } + public static . PropertyChangedAsync(.INotifyPropertyChanged @object, string propertyName, <.> action) { } + public static .RaisedEvent Raises(<> attach, <> detach, action) + where T : { } + public static .RaisedEvent RaisesAny(<> attach, <> detach, action) + where T : { } + public static .<.RaisedEvent> RaisesAnyAsync(<> attach, <> detach, <.> action) + where T : { } + public static .<.RaisedEvent> RaisesAsync(<> attach, <> detach, <.> action) + where T : { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } @@ -246,6 +256,13 @@ namespace public static void Unless([.(false)] bool condition, string reason) { } public static void When([.(true)] bool condition, string reason) { } } + public class RaisedEvent + where T : + { + public RaisedEvent(object? sender, T arguments) { } + public T Arguments { get; } + public object? Sender { get; } + } public class StringMatcher { public .StringMatcher IgnoringCase() { } @@ -4575,6 +4592,13 @@ namespace .Extensions protected override .<.> CheckAsync(.<.Stream> metadata) { } protected override string GetExpectation() { } } + public static class StrictEqualityAssertions + { + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public static .Extensions.T_IsNotStrictlyEqualTo_T_Assertion IsNotStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public static .Extensions.T_IsStrictlyEqualTo_T_Assertion IsStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + } public static class StringBuilderAssertionExtensions { public static ._HasExcessCapacity_Assertion HasExcessCapacity(this .<.StringBuilder> source) { } @@ -4718,6 +4742,20 @@ namespace .Extensions protected override .<.> CheckAsync(. metadata) { } protected override string GetExpectation() { } } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public sealed class T_IsNotStrictlyEqualTo_T_Assertion : . + { + public T_IsNotStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public sealed class T_IsStrictlyEqualTo_T_Assertion : . + { + public T_IsStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } public static class TaskAssertionExtensions { public static . IsCanceled(this . source) diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt index f7856034e6..f14cdd899b 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet8_0.verified.txt @@ -174,6 +174,16 @@ namespace where T : class { } public static void Null(T? value, [.("value")] string? expression = null) where T : struct { } + public static void PropertyChanged(.INotifyPropertyChanged @object, string propertyName, action) { } + public static . PropertyChangedAsync(.INotifyPropertyChanged @object, string propertyName, <.> action) { } + public static .RaisedEvent Raises(<> attach, <> detach, action) + where T : { } + public static .RaisedEvent RaisesAny(<> attach, <> detach, action) + where T : { } + public static .<.RaisedEvent> RaisesAnyAsync(<> attach, <> detach, <.> action) + where T : { } + public static .<.RaisedEvent> RaisesAsync(<> attach, <> detach, <.> action) + where T : { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } @@ -229,6 +239,13 @@ namespace public static void Unless([.(false)] bool condition, string reason) { } public static void When([.(true)] bool condition, string reason) { } } + public class RaisedEvent + where T : + { + public RaisedEvent(object? sender, T arguments) { } + public T Arguments { get; } + public object? Sender { get; } + } public class StringMatcher { public .StringMatcher IgnoringCase() { } @@ -4525,6 +4542,13 @@ namespace .Extensions protected override .<.> CheckAsync(.<.Stream> metadata) { } protected override string GetExpectation() { } } + public static class StrictEqualityAssertions + { + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public static .Extensions.T_IsNotStrictlyEqualTo_T_Assertion IsNotStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public static .Extensions.T_IsStrictlyEqualTo_T_Assertion IsStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + } public static class StringBuilderAssertionExtensions { public static ._HasExcessCapacity_Assertion HasExcessCapacity(this .<.StringBuilder> source) { } @@ -4668,6 +4692,20 @@ namespace .Extensions protected override .<.> CheckAsync(. metadata) { } protected override string GetExpectation() { } } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public sealed class T_IsNotStrictlyEqualTo_T_Assertion : . + { + public T_IsNotStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public sealed class T_IsStrictlyEqualTo_T_Assertion : . + { + public T_IsStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } public static class TaskAssertionExtensions { public static . IsCanceled(this . source) diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt index 4de5b374e4..656be3aba0 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.DotNet9_0.verified.txt @@ -174,6 +174,16 @@ namespace where T : class { } public static void Null(T? value, [.("value")] string? expression = null) where T : struct { } + public static void PropertyChanged(.INotifyPropertyChanged @object, string propertyName, action) { } + public static . PropertyChangedAsync(.INotifyPropertyChanged @object, string propertyName, <.> action) { } + public static .RaisedEvent Raises(<> attach, <> detach, action) + where T : { } + public static .RaisedEvent RaisesAny(<> attach, <> detach, action) + where T : { } + public static .<.RaisedEvent> RaisesAnyAsync(<> attach, <> detach, <.> action) + where T : { } + public static .<.RaisedEvent> RaisesAsync(<> attach, <> detach, <.> action) + where T : { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } @@ -246,6 +256,13 @@ namespace public static void Unless([.(false)] bool condition, string reason) { } public static void When([.(true)] bool condition, string reason) { } } + public class RaisedEvent + where T : + { + public RaisedEvent(object? sender, T arguments) { } + public T Arguments { get; } + public object? Sender { get; } + } public class StringMatcher { public .StringMatcher IgnoringCase() { } @@ -4575,6 +4592,13 @@ namespace .Extensions protected override .<.> CheckAsync(.<.Stream> metadata) { } protected override string GetExpectation() { } } + public static class StrictEqualityAssertions + { + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public static .Extensions.T_IsNotStrictlyEqualTo_T_Assertion IsNotStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public static .Extensions.T_IsStrictlyEqualTo_T_Assertion IsStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + } public static class StringBuilderAssertionExtensions { public static ._HasExcessCapacity_Assertion HasExcessCapacity(this .<.StringBuilder> source) { } @@ -4718,6 +4742,20 @@ namespace .Extensions protected override .<.> CheckAsync(. metadata) { } protected override string GetExpectation() { } } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public sealed class T_IsNotStrictlyEqualTo_T_Assertion : . + { + public T_IsNotStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } + [.("Trimming", "IL2091", Justification="Generic type parameter is only used for property access, not instantiation")] + public sealed class T_IsStrictlyEqualTo_T_Assertion : . + { + public T_IsStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } public static class TaskAssertionExtensions { public static . IsCanceled(this . source) diff --git a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt index 1571fc906f..5af184919a 100644 --- a/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt +++ b/TUnit.PublicAPI/Tests.Assertions_Library_Has_No_API_Changes.Net4_7.verified.txt @@ -137,6 +137,16 @@ namespace where T : class { } public static void Null(T? value, [.("value")] string? expression = null) where T : struct { } + public static void PropertyChanged(.INotifyPropertyChanged @object, string propertyName, action) { } + public static . PropertyChangedAsync(.INotifyPropertyChanged @object, string propertyName, <.> action) { } + public static .RaisedEvent Raises(<> attach, <> detach, action) + where T : { } + public static .RaisedEvent RaisesAny(<> attach, <> detach, action) + where T : { } + public static .<.RaisedEvent> RaisesAnyAsync(<> attach, <> detach, <.> action) + where T : { } + public static .<.RaisedEvent> RaisesAsync(<> attach, <> detach, <.> action) + where T : { } public static . That( action, [.("action")] string? expression = null) { } public static . That(.IEnumerable value, [.("value")] string? expression = null) { } public static . That(<.> action, [.("action")] string? expression = null) { } @@ -188,6 +198,13 @@ namespace public static void Unless([.(false)] bool condition, string reason) { } public static void When([.(true)] bool condition, string reason) { } } + public class RaisedEvent + where T : + { + public RaisedEvent(object? sender, T arguments) { } + public T Arguments { get; } + public object? Sender { get; } + } public class StringMatcher { public .StringMatcher IgnoringCase() { } @@ -3952,6 +3969,11 @@ namespace .Extensions protected override .<.> CheckAsync(.<.Stream> metadata) { } protected override string GetExpectation() { } } + public static class StrictEqualityAssertions + { + public static .Extensions.T_IsNotStrictlyEqualTo_T_Assertion IsNotStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + public static .Extensions.T_IsStrictlyEqualTo_T_Assertion IsStrictlyEqualTo(this . source, T expected, [.("expected")] string? expectedExpression = null) { } + } public static class StringBuilderAssertionExtensions { public static ._HasExcessCapacity_Assertion HasExcessCapacity(this .<.StringBuilder> source) { } @@ -4091,6 +4113,18 @@ namespace .Extensions protected override .<.> CheckAsync(. metadata) { } protected override string GetExpectation() { } } + public sealed class T_IsNotStrictlyEqualTo_T_Assertion : . + { + public T_IsNotStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } + public sealed class T_IsStrictlyEqualTo_T_Assertion : . + { + public T_IsStrictlyEqualTo_T_Assertion(. context, T expected) { } + protected override .<.> CheckAsync(. metadata) { } + protected override string GetExpectation() { } + } public static class TaskAssertionExtensions { public static . IsCanceled(this . source)