diff --git a/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs b/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs index f26fb8de..f75bdf55 100755 --- a/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs +++ b/src/Meziantou.Analyzer/Internals/TypeSymbolExtensions.cs @@ -24,10 +24,23 @@ public static IList GetAllInterfacesIncludingThis(this ITypeSy } public static bool InheritsFrom(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? baseClassType) + { + return InheritsFrom(classSymbol, baseClassType, visitedTypeParameters: null); + } + + private static bool InheritsFrom(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? baseClassType, HashSet? visitedTypeParameters) { if (baseClassType is null) return false; + if (classSymbol is ITypeParameterSymbol typeParameter) + { + return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) => + { + return !constraintType.IsEqualTo(baseClassType) && constraintType.InheritsFrom(baseClassType, visitedTypeParameters); + }); + } + var baseType = classSymbol.BaseType; while (baseType is not null) { @@ -41,10 +54,23 @@ public static bool InheritsFrom(this ITypeSymbol classSymbol, [NotNullWhen(true) } public static bool Implements(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType) + { + return Implements(classSymbol, interfaceType, visitedTypeParameters: null); + } + + private static bool Implements(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType, HashSet? visitedTypeParameters) { if (interfaceType is null) return false; + if (classSymbol is ITypeParameterSymbol typeParameter) + { + return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) => + { + return constraintType.IsEqualTo(interfaceType) || constraintType.Implements(interfaceType, visitedTypeParameters); + }); + } + foreach (var @interface in classSymbol.AllInterfaces) { if (@interface.IsEqualTo(interfaceType)) @@ -55,10 +81,23 @@ public static bool Implements(this ITypeSymbol classSymbol, [NotNullWhen(true)] } public static bool ImplementsGenericInterface(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType) + { + return ImplementsGenericInterface(classSymbol, interfaceType, visitedTypeParameters: null); + } + + private static bool ImplementsGenericInterface(this ITypeSymbol classSymbol, [NotNullWhen(true)] ITypeSymbol? interfaceType, HashSet? visitedTypeParameters) { if (interfaceType is null) return false; + if (classSymbol is ITypeParameterSymbol typeParameter) + { + return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) => + { + return constraintType.OriginalDefinition.IsEqualTo(interfaceType.OriginalDefinition) || constraintType.ImplementsGenericInterface(interfaceType, visitedTypeParameters); + }); + } + foreach (var iface in classSymbol.AllInterfaces) { if (iface.OriginalDefinition.IsEqualTo(interfaceType.OriginalDefinition)) @@ -73,16 +112,7 @@ public static bool IsOrImplements(this ITypeSymbol symbol, [NotNullWhen(true)] I if (interfaceType is null) return false; - if (symbol is INamedTypeSymbol { TypeKind: TypeKind.Interface } interfaceSymbol && interfaceSymbol.IsEqualTo(interfaceType)) - return true; - - foreach (var @interface in symbol.AllInterfaces) - { - if (@interface.IsEqualTo(interfaceType)) - return true; - } - - return false; + return symbol.IsEqualTo(interfaceType) || symbol.Implements(interfaceType); } public static IEnumerable GetAttributes(this ISymbol symbol, ITypeSymbol? attributeType, bool inherits = true) @@ -148,11 +178,42 @@ public static bool HasAttribute(this ISymbol symbol, [NotNullWhen(true)] ITypeSy } public static bool IsOrInheritFrom(this ITypeSymbol symbol, [NotNullWhen(true)] ITypeSymbol? expectedType) + { + return IsOrInheritFrom(symbol, expectedType, visitedTypeParameters: null); + } + + private static bool IsOrInheritFrom(this ITypeSymbol symbol, [NotNullWhen(true)] ITypeSymbol? expectedType, HashSet? visitedTypeParameters) { if (expectedType is null) return false; - return symbol.IsEqualTo(expectedType) || (!expectedType.IsSealed && symbol.InheritsFrom(expectedType)); + if (symbol.IsEqualTo(expectedType)) + return true; + + if (symbol is ITypeParameterSymbol typeParameter) + { + return AnyConstraintTypeMatches(typeParameter, visitedTypeParameters, (constraintType, visitedTypeParameters) => + { + return constraintType.IsOrInheritFrom(expectedType, visitedTypeParameters); + }); + } + + return !expectedType.IsSealed && symbol.InheritsFrom(expectedType, visitedTypeParameters); + } + + private static bool AnyConstraintTypeMatches(ITypeParameterSymbol typeParameter, HashSet? visitedTypeParameters, Func, bool> predicate) + { + visitedTypeParameters ??= []; + if (!visitedTypeParameters.Add(typeParameter)) + return false; + + foreach (var constraintType in typeParameter.ConstraintTypes) + { + if (predicate(constraintType, visitedTypeParameters)) + return true; + } + + return false; } public static bool IsEqualToAny([NotNullWhen(true)] this ITypeSymbol? symbol, params ReadOnlySpan expectedTypes) diff --git a/src/Meziantou.Analyzer/Rules/UseEventHandlerOfTAnalyzer.cs b/src/Meziantou.Analyzer/Rules/UseEventHandlerOfTAnalyzer.cs index a59fdfeb..ba9ee1b6 100644 --- a/src/Meziantou.Analyzer/Rules/UseEventHandlerOfTAnalyzer.cs +++ b/src/Meziantou.Analyzer/Rules/UseEventHandlerOfTAnalyzer.cs @@ -90,19 +90,7 @@ private bool IsValidSignature(IMethodSymbol methodSymbol, [NotNullWhen(false)] o private bool IsEventArgsType(ITypeSymbol type) { - if (type.IsOrInheritFrom(EventArgsSymbol)) - return true; - - if (type is not ITypeParameterSymbol typeParameter) - return false; - - foreach (var constraintType in typeParameter.ConstraintTypes) - { - if (IsEventArgsType(constraintType)) - return true; - } - - return false; + return type.IsOrInheritFrom(EventArgsSymbol); } } } diff --git a/tests/Meziantou.Analyzer.Test/Rules/EventsShouldHaveProperArgumentsAnalyzerTests.cs b/tests/Meziantou.Analyzer.Test/Rules/EventsShouldHaveProperArgumentsAnalyzerTests.cs index b7e10d58..196095d5 100755 --- a/tests/Meziantou.Analyzer.Test/Rules/EventsShouldHaveProperArgumentsAnalyzerTests.cs +++ b/tests/Meziantou.Analyzer.Test/Rules/EventsShouldHaveProperArgumentsAnalyzerTests.cs @@ -294,4 +294,25 @@ await CreateProjectBuilder() .ShouldFixCodeWith(Fix) .ValidateAsync(); } + + [Fact] + public async Task InvalidEventArgs_GenericTypeParameterConstraint() + { + const string SourceCode = """ + using System; + delegate void CustomEventHandler(object sender, TEventArgs e) where TEventArgs : EventArgs; + class Test where TEventArgs : EventArgs + { + public event CustomEventHandler MyEvent; + + void OnEvent() + { + MyEvent?.Invoke(this, [|null|]); + } + } + """; + await CreateProjectBuilder() + .WithSourceCode(SourceCode) + .ValidateAsync(); + } }