Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 72 additions & 11 deletions src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,23 @@ public static IList<INamedTypeSymbol> GetAllInterfacesIncludingThis(this ITypeSy
}

public static bool InheritsFrom(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? baseClassType)
{
return InheritsFrom(classSymbol, baseClassType, visitedTypeParameters: null);
}

private static bool InheritsFrom(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? baseClassType, HashSet<ITypeParameterSymbol>? visitedTypeParameters)
{
if (baseClassType is null)
return false;

if (classSymbol is ITypeParameterSymbol typeParameter)
{
return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) =>
{
return !constraintType.IsEqualTo(baseClassType) && constraintType.InheritsFrom(baseClassType, visitedTypeParameters);
});
}

var baseType = classSymbol.BaseType;
while (baseType is not null)
{
Expand All @@ -41,10 +54,23 @@ public static bool InheritsFrom(this ITypeSymbol classSymbol, [NotNullWhen(true)
}

public static bool Implements(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType)
{
return Implements(classSymbol, interfaceType, visitedTypeParameters: null);
}

private static bool Implements(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType, HashSet<ITypeParameterSymbol>? visitedTypeParameters)
{
if (interfaceType is null)
return false;

if (classSymbol is ITypeParameterSymbol typeParameter)
{
return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) =>
{
return constraintType.IsEqualTo(interfaceType) || constraintType.Implements(interfaceType, visitedTypeParameters);
});
}

foreach (var @interface in classSymbol.AllInterfaces)
{
if (@interface.IsEqualTo(interfaceType))
Expand All @@ -55,10 +81,23 @@ public static bool Implements(this ITypeSymbol classSymbol, [NotNullWhen(true)]
}

public static bool ImplementsGenericInterface(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType)
{
return ImplementsGenericInterface(classSymbol, interfaceType, visitedTypeParameters: null);
}

private static bool ImplementsGenericInterface(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType, HashSet<ITypeParameterSymbol>? visitedTypeParameters)
{
if (interfaceType is null)
return false;

if (classSymbol is ITypeParameterSymbol typeParameter)
{
return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) =>
{
return constraintType.OriginalDefinition.IsEqualTo(interfaceType.OriginalDefinition) || constraintType.ImplementsGenericInterface(interfaceType, visitedTypeParameters);
});
}

foreach (var iface in classSymbol.AllInterfaces)
{
if (iface.OriginalDefinition.IsEqualTo(interfaceType.OriginalDefinition))
Expand All @@ -73,16 +112,7 @@ public static bool IsOrImplements(this ITypeSymbol symbol, [NotNullWhen(true)] I
if (interfaceType is null)
return false;

if (symbol is INamedTypeSymbol { TypeKind: TypeKind.Interface } interfaceSymbol && interfaceSymbol.IsEqualTo(interfaceType))
return true;

foreach (var @interface in symbol.AllInterfaces)
{
if (@interface.IsEqualTo(interfaceType))
return true;
}

return false;
return symbol.IsEqualTo(interfaceType) || symbol.Implements(interfaceType);
}

public static IEnumerable<AttributeData> GetAttributes(this ISymbol symbol, ITypeSymbol? attributeType, bool inherits = true)
Expand Down Expand Up @@ -148,11 +178,42 @@ public static bool HasAttribute(this ISymbol symbol, [NotNullWhen(true)] ITypeSy
}

public static bool IsOrInheritFrom(this ITypeSymbol symbol, [NotNullWhen(true)] ITypeSymbol? expectedType)
{
return IsOrInheritFrom(symbol, expectedType, visitedTypeParameters: null);
}

private static bool IsOrInheritFrom(this ITypeSymbol symbol, [NotNullWhen(true)] ITypeSymbol? expectedType, HashSet<ITypeParameterSymbol>? visitedTypeParameters)
{
if (expectedType is null)
return false;

return symbol.IsEqualTo(expectedType) || (!expectedType.IsSealed && symbol.InheritsFrom(expectedType));
if (symbol.IsEqualTo(expectedType))
return true;

if (symbol is ITypeParameterSymbol typeParameter)
{
return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) =>
{
return constraintType.IsOrInheritFrom(expectedType, visitedTypeParameters);
});
}

return !expectedType.IsSealed && symbol.InheritsFrom(expectedType, visitedTypeParameters);
}

private static bool AnyConstraintTypeMatches(ITypeParameterSymbol typeParameter, HashSet<ITypeParameterSymbol>? visitedTypeParameters, Func<ITypeSymbol, HashSet<ITypeParameterSymbol>, bool> predicate)
{
visitedTypeParameters ??= [];
if (!visitedTypeParameters.Add(typeParameter))
return false;

foreach (var constraintType in typeParameter.ConstraintTypes)
{
if (predicate(constraintType, visitedTypeParameters))
return true;
}

return false;
}

public static bool IsEqualToAny([NotNullWhen(true)] this ITypeSymbol? symbol, params ReadOnlySpan<ITypeSymbol?> expectedTypes)
Expand Down
14 changes: 1 addition & 13 deletions src/Meziantou.Analyzer/Rules/UseEventHandlerOfTAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,7 @@ private bool IsValidSignature(IMethodSymbol methodSymbol, [NotNullWhen(false)] o

private bool IsEventArgsType(ITypeSymbol type)
{
if (type.IsOrInheritFrom(EventArgsSymbol))
return true;

if (type is not ITypeParameterSymbol typeParameter)
return false;

foreach (var constraintType in typeParameter.ConstraintTypes)
{
if (IsEventArgsType(constraintType))
return true;
}

return false;
return type.IsOrInheritFrom(EventArgsSymbol);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,25 @@ await CreateProjectBuilder()
.ShouldFixCodeWith(Fix)
.ValidateAsync();
}

[Fact]
public async Task InvalidEventArgs_GenericTypeParameterConstraint()
{
const string SourceCode = """
using System;
delegate void CustomEventHandler<TEventArgs>(object sender, TEventArgs e) where TEventArgs : EventArgs;
class Test<TEventArgs> where TEventArgs : EventArgs
{
public event CustomEventHandler<TEventArgs> MyEvent;

void OnEvent()
{
MyEvent?.Invoke(this, [|null|]);
}
}
""";
await CreateProjectBuilder()
.WithSourceCode(SourceCode)
.ValidateAsync();
}
}