diff --git a/TUnit.Analyzers.Tests/DependsOnConflictAnalyzerTests.cs b/TUnit.Analyzers.Tests/DependsOnConflictAnalyzerTests.cs index dd1859f1db..1d9ea89a63 100644 --- a/TUnit.Analyzers.Tests/DependsOnConflictAnalyzerTests.cs +++ b/TUnit.Analyzers.Tests/DependsOnConflictAnalyzerTests.cs @@ -80,10 +80,20 @@ public class MyClass2 .WithMessage("DependsOn Conflicts: MyClass1.Test > MyClass2.Test > MyClass1.Test") .WithLocation(0), + Verifier + .Diagnostic(Rules.DependsOnConflicts) + .WithMessage("DependsOn Conflicts: MyClass1.Test2 > MyClass2.Test > MyClass1.Test > MyClass2.Test2 > MyClass1.Test2") + .WithLocation(1), + Verifier .Diagnostic(Rules.DependsOnConflicts) .WithMessage("DependsOn Conflicts: MyClass2.Test > MyClass1.Test > MyClass2.Test") - .WithLocation(2) + .WithLocation(2), + + Verifier + .Diagnostic(Rules.DependsOnConflicts) + .WithMessage("DependsOn Conflicts: MyClass2.Test2 > MyClass1.Test > MyClass2.Test > MyClass1.Test2 > MyClass2.Test2") + .WithLocation(3) ); } @@ -129,10 +139,20 @@ public class MyClass2 .WithMessage("DependsOn Conflicts: MyClass1.Test > MyClass2.Test > MyClass1.Test") .WithLocation(0), + Verifier + .Diagnostic(Rules.DependsOnConflicts) + .WithMessage("DependsOn Conflicts: MyClass1.Test2 > MyClass2.Test > MyClass1.Test > MyClass2.Test2 > MyClass1.Test2") + .WithLocation(1), + Verifier .Diagnostic(Rules.DependsOnConflicts) .WithMessage("DependsOn Conflicts: MyClass2.Test > MyClass1.Test > MyClass2.Test") - .WithLocation(2) + .WithLocation(2), + + Verifier + .Diagnostic(Rules.DependsOnConflicts) + .WithMessage("DependsOn Conflicts: MyClass2.Test2 > MyClass1.Test > MyClass2.Test > MyClass1.Test2 > MyClass2.Test2") + .WithLocation(3) ); } diff --git a/TUnit.Analyzers/DependsOnConflictAnalyzer.cs b/TUnit.Analyzers/DependsOnConflictAnalyzer.cs index e30f22ce48..69c3e43570 100644 --- a/TUnit.Analyzers/DependsOnConflictAnalyzer.cs +++ b/TUnit.Analyzers/DependsOnConflictAnalyzer.cs @@ -1,4 +1,4 @@ -using System.Collections.Immutable; +using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Diagnostics; using TUnit.Analyzers.Extensions; @@ -6,27 +6,6 @@ namespace TUnit.Analyzers; -public record Chain(IMethodSymbol OriginalMethod) -{ - public List Dependencies { get; } = new List(); - - public bool MethodTraversed(IMethodSymbol method) => Dependencies.Contains(method, SymbolEqualityComparer.Default); - - public bool Any() => Dependencies.Any(); - - public void Add(IMethodSymbol dependency) - { - Dependencies.Add(dependency); - } - - public IMethodSymbol[] GetCompleteChain() - { - return new[] { OriginalMethod } - .Concat(Dependencies.TakeUntil(d => SymbolEqualityComparer.Default.Equals(d, OriginalMethod))) - .ToArray(); - } -} - [DiagnosticAnalyzer(LanguageNames.CSharp)] public class DependsOnConflictAnalyzer : ConcurrentDiagnosticAnalyzer { @@ -44,16 +23,23 @@ private void AnalyzeSymbol(SymbolAnalysisContext context) var dependsOnAttributes = GetDependsOnAttributes(method).Concat(GetDependsOnAttributes(method.ReceiverType ?? method.ContainingType)).ToArray(); - var dependencies = GetDependencies(context, new Chain(method), method, dependsOnAttributes); - - if (!dependencies.Any() || !dependencies.MethodTraversed(method)) + if (dependsOnAttributes.Length == 0) { return; } - context.ReportDiagnostic(Diagnostic.Create(Rules.DependsOnConflicts, - method.Locations.FirstOrDefault(), - string.Join(" > ", dependencies.GetCompleteChain().Select(x => $"{(x.ReceiverType ?? x.ContainingType).Name}.{x.Name}")))); + var cyclePath = new List(); + var visited = new HashSet(SymbolEqualityComparer.Default); + + if (FindCycleBackTo(context, method, method, dependsOnAttributes, cyclePath, visited)) + { + var chainDescription = string.Join(" > ", + new[] { method }.Concat(cyclePath).Concat(new[] { method }) + .Select(x => $"{(x.ReceiverType ?? x.ContainingType).Name}.{x.Name}")); + + context.ReportDiagnostic(Diagnostic.Create(Rules.DependsOnConflicts, + method.Locations.FirstOrDefault(), chainDescription)); + } } private AttributeData[] GetDependsOnAttributes(ISymbol methodSymbol) @@ -65,22 +51,27 @@ private AttributeData[] GetDependsOnAttributes(ISymbol methodSymbol) .ToArray(); } - private Chain GetDependencies(SymbolAnalysisContext context, Chain chain, - IMethodSymbol methodToGetDependenciesFor, AttributeData[] dependsOnAttributes) + /// + /// Performs a DFS from looking for a path back to . + /// Returns true if a cycle is found, with containing the intermediate methods. + /// + private bool FindCycleBackTo(SymbolAnalysisContext context, IMethodSymbol targetMethod, + IMethodSymbol currentMethod, AttributeData[] dependsOnAttributes, + List path, HashSet visited) { - if (!methodToGetDependenciesFor.IsTestMethod(context.Compilation)) + if (!currentMethod.IsTestMethod(context.Compilation)) { - return chain; + return false; } - if (!dependsOnAttributes.Any()) + if (dependsOnAttributes.Length == 0) { - return chain; + return false; } foreach (var dependsOnAttribute in dependsOnAttributes) { - var dependencyType = GetTypeContainingMethod(methodToGetDependenciesFor, dependsOnAttribute); + var dependencyType = GetTypeContainingMethod(currentMethod, dependsOnAttribute); var dependencyMethodName = dependsOnAttribute.ConstructorArguments .FirstOrNull(x => x.Kind == TypedConstantKind.Primitive)?.Value as string; @@ -93,7 +84,7 @@ private Chain GetDependencies(SymbolAnalysisContext context, Chain chain, if (dependencyType is not INamedTypeSymbol namedTypeSymbol) { - return chain; + continue; } var methods = namedTypeSymbol @@ -103,47 +94,50 @@ private Chain GetDependencies(SymbolAnalysisContext context, Chain chain, .Where(x => x.MethodKind == MethodKind.Ordinary) .ToArray(); - if (!methods.Any()) + if (methods.Length == 0) { context.ReportDiagnostic(Diagnostic.Create(Rules.NoMethodFound, dependsOnAttribute.GetLocation())); - - return chain; + continue; } var foundDependencies = FilterMethods(dependencyMethodName, methods, dependencyParameterTypes); - if (!foundDependencies.Any()) + if (foundDependencies.Length == 0) { context.ReportDiagnostic(Diagnostic.Create(Rules.NoMethodFound, dependsOnAttribute.GetLocation())); - return chain; + continue; } foreach (var foundDependency in foundDependencies) { - if (chain.MethodTraversed(foundDependency)) + // Found a cycle back to the target method + if (SymbolEqualityComparer.Default.Equals(foundDependency, targetMethod)) { - chain.Add(foundDependency); - return chain; + return true; } - chain.Add(foundDependency); + // Skip already-visited methods to avoid infinite recursion + if (!visited.Add(foundDependency)) + { + continue; + } - var nestedChain = GetDependencies(context, chain, foundDependency, GetDependsOnAttributes(foundDependency).Concat(GetDependsOnAttributes(foundDependency.ReceiverType ?? foundDependency.ContainingType)).ToArray()); + path.Add(foundDependency); - foreach (var nestedDependency in nestedChain.Dependencies) - { - if (chain.MethodTraversed(nestedDependency)) - { - chain.Add(nestedDependency); - return chain; - } + var nestedAttributes = GetDependsOnAttributes(foundDependency) + .Concat(GetDependsOnAttributes(foundDependency.ReceiverType ?? foundDependency.ContainingType)) + .ToArray(); - chain.Add(nestedDependency); + if (FindCycleBackTo(context, targetMethod, foundDependency, nestedAttributes, path, visited)) + { + return true; } + + path.RemoveAt(path.Count - 1); } } - return chain; + return false; } private static ITypeSymbol GetTypeContainingMethod(IMethodSymbol methodToGetDependenciesFor, AttributeData dependsOnAttribute)