diff --git a/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs b/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs index d749df24d..ebdaac79a 100644 --- a/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs +++ b/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs @@ -75,6 +75,9 @@ public override void Initialize(AnalysisContext context) IMethodSymbol taskActivityRunAsync = knownSymbols.TaskActivityBase.GetMembers("RunAsync").OfType().Single(); INamedTypeSymbol voidSymbol = context.Compilation.GetSpecialType(SpecialType.System_Void); + // Get common DI types that should not be treated as activity input + INamedTypeSymbol? functionContextSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Azure.Functions.Worker.FunctionContext"); + // Search for Activity invocations ConcurrentBag invocations = []; context.RegisterOperationAction( @@ -161,6 +164,12 @@ public override void Initialize(AnalysisContext context) return; } + // If the parameter is FunctionContext, skip validation for this activity (it's a DI parameter, not real input) + if (functionContextSymbol != null && SymbolEqualityComparer.Default.Equals(inputParam.Type, functionContextSymbol)) + { + return; + } + ITypeSymbol? inputType = inputParam.Type; ITypeSymbol? outputType = methodSymbol.ReturnType; @@ -306,7 +315,8 @@ public override void Initialize(AnalysisContext context) continue; } - if (!SymbolEqualityComparer.Default.Equals(invocation.InputType, activity.InputType)) + // Check input type compatibility + if (!AreTypesCompatible(ctx.Compilation, invocation.InputType, activity.InputType)) { string actual = invocation.InputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; string expected = activity.InputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; @@ -316,7 +326,8 @@ public override void Initialize(AnalysisContext context) ctx.ReportDiagnostic(diagnostic); } - if (!SymbolEqualityComparer.Default.Equals(invocation.OutputType, activity.OutputType)) + // Check output type compatibility + if (!AreTypesCompatible(ctx.Compilation, activity.OutputType, invocation.OutputType)) { string actual = invocation.OutputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; string expected = activity.OutputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; @@ -330,6 +341,148 @@ public override void Initialize(AnalysisContext context) }); } + /// + /// Checks if the source type is compatible with (can be assigned to) the target type. + /// This handles polymorphism, interface implementation, inheritance, and collection type compatibility. + /// + static bool AreTypesCompatible(Compilation compilation, ITypeSymbol? sourceType, ITypeSymbol? targetType) + { + // Both null = compatible (no input/output on both sides) + if (sourceType == null && targetType == null) + { + return true; + } + + // Special case: null (no input/output provided) can be passed to explicitly nullable parameters + // This handles nullable value types (int?) and nullable reference types (string?) + if (sourceType == null && targetType != null) + { + // Check if target is a nullable value type (Nullable) + if (targetType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + return true; + } + + // Check if target is a nullable reference type (string?) + if (targetType.NullableAnnotation == NullableAnnotation.Annotated) + { + return true; + } + + // Not nullable, so null input is incompatible + return false; + } + + // If targetType is null but sourceType is not, they're incompatible + if (targetType == null && sourceType != null) + { + return false; + } + + // Check if types are exactly equal + if (SymbolEqualityComparer.Default.Equals(sourceType, targetType)) + { + return true; + } + + // Check if source type can be converted to target type (handles inheritance, interface implementation, etc.) + // At this point, both sourceType and targetType are guaranteed to be non-null + Conversion conversion = compilation.ClassifyConversion(sourceType!, targetType!); + if (conversion.IsImplicit || conversion.IsIdentity) + { + return true; + } + + // Special handling for collection types since ClassifyConversion doesn't always recognize + // generic interface implementations (e.g., List to IReadOnlyList) + // At this point, both sourceType and targetType are guaranteed to be non-null + if (IsCollectionTypeCompatible(sourceType!, targetType!)) + { + return true; + } + + return false; + } + + /// + /// Checks if the source collection type is compatible with the target collection type. + /// Handles common scenarios like List to IReadOnlyList, arrays to IEnumerable, etc. + /// + static bool IsCollectionTypeCompatible(ITypeSymbol sourceType, ITypeSymbol targetType) + { + // Check if source is an array and target is a collection interface + if (sourceType is IArrayTypeSymbol sourceArray && targetType is INamedTypeSymbol targetNamed) + { + return IsArrayCompatibleWithCollectionInterface(sourceArray, targetNamed); + } + + // Both must be generic named types + if (sourceType is not INamedTypeSymbol sourceNamed || targetType is not INamedTypeSymbol targetNamedType) + { + return false; + } + + // Both must be generic types with the same type arguments + if (!sourceNamed.IsGenericType || !targetNamedType.IsGenericType) + { + return false; + } + + if (sourceNamed.TypeArguments.Length != targetNamedType.TypeArguments.Length) + { + return false; + } + + // Check if type arguments are compatible (could be different but compatible types) + for (int i = 0; i < sourceNamed.TypeArguments.Length; i++) + { + if (!SymbolEqualityComparer.Default.Equals(sourceNamed.TypeArguments[i], targetNamedType.TypeArguments[i])) + { + // Type arguments must match exactly for collections (we don't support covariance/contravariance here) + return false; + } + } + + // Check if source type implements or derives from target type + // This handles: List → IReadOnlyList, List → IEnumerable, etc. + return ImplementsInterface(sourceNamed, targetNamedType); + } + + /// + /// Checks if an array type is compatible with a collection interface. + /// + static bool IsArrayCompatibleWithCollectionInterface(IArrayTypeSymbol arrayType, INamedTypeSymbol targetInterface) + { + if (!targetInterface.IsGenericType || targetInterface.TypeArguments.Length != 1) + { + return false; + } + + // Check if array element type matches the generic type argument + if (!SymbolEqualityComparer.Default.Equals(arrayType.ElementType, targetInterface.TypeArguments[0])) + { + return false; + } + + // Array implements: IEnumerable, ICollection, IList, IReadOnlyCollection, IReadOnlyList + string targetName = targetInterface.OriginalDefinition.ToDisplayString(); + return targetName == "System.Collections.Generic.IEnumerable" || + targetName == "System.Collections.Generic.ICollection" || + targetName == "System.Collections.Generic.IList" || + targetName == "System.Collections.Generic.IReadOnlyCollection" || + targetName == "System.Collections.Generic.IReadOnlyList"; + } + + /// + /// Checks if the source type implements the target interface. + /// + static bool ImplementsInterface(INamedTypeSymbol sourceType, INamedTypeSymbol targetInterface) + { + // Check all interfaces implemented by the source type + return sourceType.AllInterfaces.Any(@interface => + SymbolEqualityComparer.Default.Equals(@interface.OriginalDefinition, targetInterface.OriginalDefinition)); + } + struct ActivityInvocation { public string Name { get; set; } diff --git a/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs b/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs index 2970f6adf..876027cc7 100644 --- a/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs +++ b/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs @@ -406,7 +406,6 @@ async Task Method(TaskOrchestrationContext context) await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); } - static DiagnosticResult BuildInputDiagnostic() { return VerifyCS.Diagnostic(MatchingInputOutputTypeActivityAnalyzer.InputArgumentTypeMismatchDiagnosticId);