Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions TUnit.Analyzers.Tests/DependsOnConflictAnalyzerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,20 @@ public class MyClass2
.WithMessage("DependsOn Conflicts: MyClass1.Test > MyClass2.Test > MyClass1.Test")
.WithLocation(0),

Verifier
.Diagnostic(Rules.DependsOnConflicts)
.WithMessage("DependsOn Conflicts: MyClass1.Test2 > MyClass2.Test > MyClass1.Test > MyClass2.Test2 > MyClass1.Test2")
.WithLocation(1),

Verifier
.Diagnostic(Rules.DependsOnConflicts)
.WithMessage("DependsOn Conflicts: MyClass2.Test > MyClass1.Test > MyClass2.Test")
.WithLocation(2)
.WithLocation(2),

Verifier
.Diagnostic(Rules.DependsOnConflicts)
.WithMessage("DependsOn Conflicts: MyClass2.Test2 > MyClass1.Test > MyClass2.Test > MyClass1.Test2 > MyClass2.Test2")
.WithLocation(3)
);
}

Expand Down Expand Up @@ -129,10 +139,20 @@ public class MyClass2
.WithMessage("DependsOn Conflicts: MyClass1.Test > MyClass2.Test > MyClass1.Test")
.WithLocation(0),

Verifier
.Diagnostic(Rules.DependsOnConflicts)
.WithMessage("DependsOn Conflicts: MyClass1.Test2 > MyClass2.Test > MyClass1.Test > MyClass2.Test2 > MyClass1.Test2")
.WithLocation(1),

Verifier
.Diagnostic(Rules.DependsOnConflicts)
.WithMessage("DependsOn Conflicts: MyClass2.Test > MyClass1.Test > MyClass2.Test")
.WithLocation(2)
.WithLocation(2),

Verifier
.Diagnostic(Rules.DependsOnConflicts)
.WithMessage("DependsOn Conflicts: MyClass2.Test2 > MyClass1.Test > MyClass2.Test > MyClass1.Test2 > MyClass2.Test2")
.WithLocation(3)
);
}

Expand Down
104 changes: 49 additions & 55 deletions TUnit.Analyzers/DependsOnConflictAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,11 @@
using System.Collections.Immutable;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using TUnit.Analyzers.Extensions;
using TUnit.Analyzers.Helpers;

namespace TUnit.Analyzers;

public record Chain(IMethodSymbol OriginalMethod)
{
public List<IMethodSymbol> Dependencies { get; } = new List<IMethodSymbol>();

public bool MethodTraversed(IMethodSymbol method) => Dependencies.Contains(method, SymbolEqualityComparer.Default);

public bool Any() => Dependencies.Any();

public void Add(IMethodSymbol dependency)
{
Dependencies.Add(dependency);
}

public IMethodSymbol[] GetCompleteChain()
{
return new[] { OriginalMethod }
.Concat(Dependencies.TakeUntil(d => SymbolEqualityComparer.Default.Equals(d, OriginalMethod)))
.ToArray();
}
}

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class DependsOnConflictAnalyzer : ConcurrentDiagnosticAnalyzer
{
Expand All @@ -44,16 +23,23 @@ private void AnalyzeSymbol(SymbolAnalysisContext context)

var dependsOnAttributes = GetDependsOnAttributes(method).Concat(GetDependsOnAttributes(method.ReceiverType ?? method.ContainingType)).ToArray();

var dependencies = GetDependencies(context, new Chain(method), method, dependsOnAttributes);

if (!dependencies.Any() || !dependencies.MethodTraversed(method))
if (dependsOnAttributes.Length == 0)
{
return;
}

context.ReportDiagnostic(Diagnostic.Create(Rules.DependsOnConflicts,
method.Locations.FirstOrDefault(),
string.Join(" > ", dependencies.GetCompleteChain().Select(x => $"{(x.ReceiverType ?? x.ContainingType).Name}.{x.Name}"))));
var cyclePath = new List<IMethodSymbol>();
var visited = new HashSet<IMethodSymbol>(SymbolEqualityComparer.Default);

if (FindCycleBackTo(context, method, method, dependsOnAttributes, cyclePath, visited))
{
var chainDescription = string.Join(" > ",
new[] { method }.Concat(cyclePath).Concat(new[] { method })
.Select(x => $"{(x.ReceiverType ?? x.ContainingType).Name}.{x.Name}"));

context.ReportDiagnostic(Diagnostic.Create(Rules.DependsOnConflicts,
method.Locations.FirstOrDefault(), chainDescription));
}
}

private AttributeData[] GetDependsOnAttributes(ISymbol methodSymbol)
Expand All @@ -65,22 +51,27 @@ private AttributeData[] GetDependsOnAttributes(ISymbol methodSymbol)
.ToArray();
}

private Chain GetDependencies(SymbolAnalysisContext context, Chain chain,
IMethodSymbol methodToGetDependenciesFor, AttributeData[] dependsOnAttributes)
/// <summary>
/// Performs a DFS from <paramref name="currentMethod"/> looking for a path back to <paramref name="targetMethod"/>.
/// Returns true if a cycle is found, with <paramref name="path"/> containing the intermediate methods.
/// </summary>
private bool FindCycleBackTo(SymbolAnalysisContext context, IMethodSymbol targetMethod,
IMethodSymbol currentMethod, AttributeData[] dependsOnAttributes,
List<IMethodSymbol> path, HashSet<IMethodSymbol> visited)
{
if (!methodToGetDependenciesFor.IsTestMethod(context.Compilation))
if (!currentMethod.IsTestMethod(context.Compilation))
{
return chain;
return false;
}

if (!dependsOnAttributes.Any())
if (dependsOnAttributes.Length == 0)
{
return chain;
return false;
}

foreach (var dependsOnAttribute in dependsOnAttributes)
{
var dependencyType = GetTypeContainingMethod(methodToGetDependenciesFor, dependsOnAttribute);
var dependencyType = GetTypeContainingMethod(currentMethod, dependsOnAttribute);

var dependencyMethodName = dependsOnAttribute.ConstructorArguments
.FirstOrNull(x => x.Kind == TypedConstantKind.Primitive)?.Value as string;
Expand All @@ -93,7 +84,7 @@ private Chain GetDependencies(SymbolAnalysisContext context, Chain chain,

if (dependencyType is not INamedTypeSymbol namedTypeSymbol)
{
return chain;
continue;
}

var methods = namedTypeSymbol
Expand All @@ -103,47 +94,50 @@ private Chain GetDependencies(SymbolAnalysisContext context, Chain chain,
.Where(x => x.MethodKind == MethodKind.Ordinary)
.ToArray();

if (!methods.Any())
if (methods.Length == 0)
{
context.ReportDiagnostic(Diagnostic.Create(Rules.NoMethodFound, dependsOnAttribute.GetLocation()));

return chain;
continue;
}

var foundDependencies = FilterMethods(dependencyMethodName, methods, dependencyParameterTypes);

if (!foundDependencies.Any())
if (foundDependencies.Length == 0)
{
context.ReportDiagnostic(Diagnostic.Create(Rules.NoMethodFound, dependsOnAttribute.GetLocation()));
return chain;
continue;
}

foreach (var foundDependency in foundDependencies)
{
if (chain.MethodTraversed(foundDependency))
// Found a cycle back to the target method
if (SymbolEqualityComparer.Default.Equals(foundDependency, targetMethod))
{
chain.Add(foundDependency);
return chain;
return true;
}

chain.Add(foundDependency);
// Skip already-visited methods to avoid infinite recursion
if (!visited.Add(foundDependency))
{
continue;
}

var nestedChain = GetDependencies(context, chain, foundDependency, GetDependsOnAttributes(foundDependency).Concat(GetDependsOnAttributes(foundDependency.ReceiverType ?? foundDependency.ContainingType)).ToArray());
path.Add(foundDependency);

foreach (var nestedDependency in nestedChain.Dependencies)
{
if (chain.MethodTraversed(nestedDependency))
{
chain.Add(nestedDependency);
return chain;
}
var nestedAttributes = GetDependsOnAttributes(foundDependency)
.Concat(GetDependsOnAttributes(foundDependency.ReceiverType ?? foundDependency.ContainingType))
.ToArray();

chain.Add(nestedDependency);
if (FindCycleBackTo(context, targetMethod, foundDependency, nestedAttributes, path, visited))
{
return true;
}

path.RemoveAt(path.Count - 1);
}
}

return chain;
return false;
}

private static ITypeSymbol GetTypeContainingMethod(IMethodSymbol methodToGetDependenciesFor, AttributeData dependsOnAttribute)
Expand Down
Loading